Skip to content

Commit 8b4e3dd

Browse files
committed
MLP and Hopfield trials
1 parent 8bc8043 commit 8b4e3dd

File tree

5 files changed

+321
-10
lines changed

5 files changed

+321
-10
lines changed

BitBoardGen/src/com/rahul/bbgen/Generator.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ public String prettyPrintBoards() {
145145

146146
public void buildBoards() {
147147
try {
148-
DataSetRow dataSetRow;
148+
MyDataSetRow dataSetRow;
149149
BufferedReader br = new BufferedReader(new FileReader(new File(
150150
"./data/game4.fen")));
151151
String fen;
@@ -182,7 +182,7 @@ private void openObjectStream() {
182182
}
183183
}
184184

185-
private void storeInputData(DataSetRow dataSetRow) {
185+
private void storeInputData(MyDataSetRow dataSetRow) {
186186

187187
try {
188188
oos.writeObject(dataSetRow);
@@ -191,8 +191,8 @@ private void storeInputData(DataSetRow dataSetRow) {
191191
}
192192
}
193193

194-
private DataSetRow buildInputData() {
195-
DataSetRow dataSetRow = new DataSetRow();
194+
private MyDataSetRow buildInputData() {
195+
MyDataSetRow dataSetRow = new MyDataSetRow();
196196
dataSetRow.setBBishops(BBishops.getBitBoard());
197197
dataSetRow.setBKing(BKing.getBitBoard());
198198
dataSetRow.setBKnights(BKnights.getBitBoard());
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
package com.rahul.bbgen;
2+
3+
import java.io.Serializable;
4+
import java.util.Arrays;
5+
6+
public class MyDataSetRow implements Serializable {
7+
8+
private static final long serialVersionUID = 1L;
9+
10+
private double[] WRooks;
11+
private double[] WBishops;
12+
private double[] WPawns;
13+
private double[] WKnights;
14+
private double[] WKing;
15+
private double[] WQueen;
16+
17+
private double[] BRooks;
18+
private double[] BBishops;
19+
private double[] BPawns;
20+
private double[] BKnights;
21+
private double[] BKing;
22+
private double[] BQueen;
23+
24+
private double[] evalScore;
25+
26+
public double[] getRow() {
27+
return concatAll(WRooks, WBishops, WPawns, WKnights, WKing, WQueen,
28+
BRooks, BBishops, BPawns, BKnights, BKing, BQueen);
29+
}
30+
31+
private double[] concatAll(double[] first, double[]... rest) {
32+
int totalLength = first.length;
33+
for (double[] array : rest) {
34+
totalLength += array.length;
35+
}
36+
double[] result = Arrays.copyOf(first, totalLength);
37+
int offset = first.length;
38+
for (double[] array : rest) {
39+
System.arraycopy(array, 0, result, offset, array.length);
40+
offset += array.length;
41+
}
42+
return result;
43+
}
44+
45+
public double[] getWRooks() {
46+
return WRooks;
47+
}
48+
49+
public void setWRooks(double[] wRooks) {
50+
WRooks = wRooks;
51+
}
52+
53+
public double[] getWBishops() {
54+
return WBishops;
55+
}
56+
57+
public void setWBishops(double[] wBishops) {
58+
WBishops = wBishops;
59+
}
60+
61+
public double[] getWPawns() {
62+
return WPawns;
63+
}
64+
65+
public void setWPawns(double[] wPawns) {
66+
WPawns = wPawns;
67+
}
68+
69+
public double[] getWKnights() {
70+
return WKnights;
71+
}
72+
73+
public void setWKnights(double[] wKnights) {
74+
WKnights = wKnights;
75+
}
76+
77+
public double[] getWKing() {
78+
return WKing;
79+
}
80+
81+
public void setWKing(double[] wKing) {
82+
WKing = wKing;
83+
}
84+
85+
public double[] getWQueen() {
86+
return WQueen;
87+
}
88+
89+
public void setWQueen(double[] wQueen) {
90+
WQueen = wQueen;
91+
}
92+
93+
public double[] getBRooks() {
94+
return BRooks;
95+
}
96+
97+
public void setBRooks(double[] bRooks) {
98+
BRooks = bRooks;
99+
}
100+
101+
public double[] getBBishops() {
102+
return BBishops;
103+
}
104+
105+
public void setBBishops(double[] bBishops) {
106+
BBishops = bBishops;
107+
}
108+
109+
public double[] getBPawns() {
110+
return BPawns;
111+
}
112+
113+
public void setBPawns(double[] bPawns) {
114+
BPawns = bPawns;
115+
}
116+
117+
public double[] getBKnights() {
118+
return BKnights;
119+
}
120+
121+
public void setBKnights(double[] bKnights) {
122+
BKnights = bKnights;
123+
}
124+
125+
public double[] getBKing() {
126+
return BKing;
127+
}
128+
129+
public void setBKing(double[] bKing) {
130+
BKing = bKing;
131+
}
132+
133+
public double[] getBQueen() {
134+
return BQueen;
135+
}
136+
137+
public void setBQueen(double[] bQueen) {
138+
BQueen = bQueen;
139+
}
140+
141+
public double[] getEvalScore() {
142+
return evalScore;
143+
}
144+
145+
public void setEvalScore(double[] evalScore) {
146+
this.evalScore = evalScore;
147+
}
148+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package com.rahul.trainer;
2+
3+
import java.util.ArrayList;
4+
import java.util.Arrays;
5+
6+
import org.neuroph.core.data.DataSet;
7+
import org.neuroph.core.data.DataSetRow;
8+
import org.neuroph.nnet.Hopfield;
9+
10+
import com.rahul.bbgen.BitBoard;
11+
import com.rahul.bbgen.MyDataSetRow;
12+
13+
public class HopfieldNetwork {
14+
15+
public static void main(String args[]) {
16+
17+
ArrayList<MyDataSetRow> data;
18+
19+
// create training set
20+
data = PrepareData.getData();
21+
DataSet trainingSet = new DataSet(BitBoard.BOARD_LENGTH * 12, 1);
22+
for (MyDataSetRow row : data) {
23+
trainingSet.addRow(row.getRow(), row.getEvalScore());
24+
}
25+
26+
// create hopfield network
27+
Hopfield myHopfield = new Hopfield(BitBoard.BOARD_LENGTH * 12);
28+
// learn the training set
29+
myHopfield.learn(trainingSet);
30+
31+
// test hopfield network
32+
System.out.println("Testing network");
33+
34+
// print network output for the each element from the specified training
35+
// set.
36+
int counter = 0;
37+
for (DataSetRow trainingSetRow : trainingSet.getRows()) {
38+
myHopfield.setInput(trainingSetRow.getInput());
39+
myHopfield.calculate();
40+
double[] networkOutput = myHopfield.getOutput();
41+
42+
printArray(trainingSetRow.getInput());
43+
printArray(networkOutput);
44+
printPairwiseError(trainingSetRow.getInput(), networkOutput);
45+
46+
counter++;
47+
}
48+
System.out.println("Total rows : " + counter);
49+
}
50+
51+
private static void printArray(double[] a) {
52+
for(int i=0;i<a.length; i++)
53+
System.out.printf("%.0f", a[i]);
54+
System.out.println();
55+
}
56+
57+
private static void printPairwiseError(double[] a, double[] b) {
58+
float error = 0f;
59+
for (int i = 0; i < a.length; i++) {
60+
if (a[i] != b[i])
61+
error++;
62+
}
63+
System.out.println(error / a.length);
64+
}
65+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package com.rahul.trainer;
2+
3+
import java.util.ArrayList;
4+
import java.util.Arrays;
5+
6+
import org.neuroph.core.NeuralNetwork;
7+
import org.neuroph.core.data.DataSet;
8+
import org.neuroph.core.data.DataSetRow;
9+
import org.neuroph.core.events.LearningEvent;
10+
import org.neuroph.core.events.LearningEventListener;
11+
import org.neuroph.core.learning.LearningRule;
12+
import org.neuroph.nnet.MultiLayerPerceptron;
13+
import org.neuroph.nnet.learning.BackPropagation;
14+
import org.neuroph.nnet.learning.MomentumBackpropagation;
15+
import org.neuroph.util.TransferFunctionType;
16+
17+
import com.rahul.bbgen.BitBoard;
18+
import com.rahul.bbgen.MyDataSetRow;
19+
20+
public class MLP implements LearningEventListener {
21+
22+
private ArrayList<MyDataSetRow> data;
23+
24+
public static void main(String[] args) {
25+
new MLP().run();
26+
}
27+
28+
public void run() {
29+
30+
// create training set
31+
data = PrepareData.getData();
32+
DataSet trainingSet = new DataSet(BitBoard.BOARD_LENGTH * 12, 1);
33+
for (MyDataSetRow row : data) {
34+
trainingSet.addRow(row.getRow(), row.getEvalScore());
35+
}
36+
37+
// trainingSet.normalize();
38+
39+
// create multi layer perceptron
40+
MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(
41+
TransferFunctionType.TANH, BitBoard.BOARD_LENGTH * 12,
42+
BitBoard.BOARD_LENGTH * 6, 1);
43+
44+
// enable batch if using MomentumBackpropagation
45+
if (myMlPerceptron.getLearningRule() instanceof MomentumBackpropagation)
46+
((MomentumBackpropagation) myMlPerceptron.getLearningRule())
47+
.setBatchMode(true);
48+
49+
LearningRule learningRule = myMlPerceptron.getLearningRule();
50+
learningRule.addListener(this);
51+
52+
// learn the training set
53+
System.out.println("Training neural network...");
54+
myMlPerceptron.learn(trainingSet);
55+
56+
// test perceptron
57+
System.out.println("Testing trained neural network");
58+
testNeuralNetwork(myMlPerceptron, trainingSet);
59+
60+
// save trained neural network
61+
myMlPerceptron.save("2Layer_try01.nnet");
62+
63+
// load saved neural network
64+
NeuralNetwork loadedMlPerceptron = NeuralNetwork
65+
.load("2Layer_try01.nnet");
66+
67+
// test loaded neural network
68+
System.out.println("Testing loaded neural network");
69+
testNeuralNetwork(loadedMlPerceptron, trainingSet);
70+
}
71+
72+
public static void testNeuralNetwork(NeuralNetwork neuralNet,
73+
DataSet testSet) {
74+
75+
for (DataSetRow testSetRow : testSet.getRows()) {
76+
neuralNet.setInput(testSetRow.getInput());
77+
neuralNet.calculate();
78+
double[] networkOutput = neuralNet.getOutput();
79+
80+
System.out
81+
.print("Input: " + Arrays.toString(testSetRow.getInput()));
82+
System.out.println(" Output: " + Arrays.toString(networkOutput));
83+
}
84+
}
85+
86+
@Override
87+
public void handleLearningEvent(LearningEvent event) {
88+
BackPropagation bp = (BackPropagation) event.getSource();
89+
System.out.println(bp.getCurrentIteration() + ". iteration : "
90+
+ bp.getTotalNetworkError());
91+
}
92+
93+
}

BitBoardGen/src/com/rahul/trainer/PrepareData.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
import java.io.ObjectInputStream;
55
import java.util.ArrayList;
66

7-
import com.rahul.bbgen.DataSetRow;
7+
import com.rahul.bbgen.MyDataSetRow;
88

99
public class PrepareData {
1010

11-
public static ArrayList<DataSetRow> data = new ArrayList<DataSetRow>();
11+
private static ArrayList<MyDataSetRow> data = new ArrayList<MyDataSetRow>();
1212

1313
private static void readData() {
1414
ObjectInputStream ois;
15-
DataSetRow dataSetRow;
15+
MyDataSetRow dataSetRow;
1616
try {
1717
ois = new ObjectInputStream(new FileInputStream("./data/game4.bb"));
18-
while((dataSetRow = (DataSetRow) ois.readObject()) != null) {
18+
while((dataSetRow = (MyDataSetRow) ois.readObject()) != null) {
1919
data.add(dataSetRow);
2020
}
2121
ois.close();
@@ -25,9 +25,8 @@ private static void readData() {
2525

2626
private static void printData() {
2727
int i = 0;
28-
for(DataSetRow dataSetRow : data) {
28+
for(MyDataSetRow dataSetRow : data) {
2929
prettyPrint(dataSetRow.getRow());
30-
System.out.println(" " + dataSetRow.getEvalScore()[0]);
3130
i++;
3231
}
3332
System.out.println("Total rows : " + i);
@@ -36,6 +35,12 @@ private static void printData() {
3635
private static void prettyPrint(double[] row) {
3736
for(double d : row)
3837
System.out.printf("%.0f", d);
38+
System.out.println();
39+
}
40+
41+
public static ArrayList<MyDataSetRow> getData() {
42+
readData();
43+
return data;
3944
}
4045

4146
public static void main(String[] args) {

0 commit comments

Comments
 (0)