/// <summary>
/// Classifies the Fashion-MNIST dataset with a fully connected neural network.
/// </summary>
public class NN_MultipleClassification_Fashion_MNIST
{
    private readonly string TrainImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train";
    private readonly string TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\test";
    private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_data.bin";
    private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_label.bin";

    private readonly int img_rows = 28;
    private readonly int img_cols = 28;
    private readonly int num_classes = 10; // total classes

    // Markers written at the start of the cache files. A file that does not begin with the
    // expected marker (for example one written by the former BinaryFormatter-based code) is
    // rejected, and the training set is rebuilt from the PNG images instead.
    private const int ImageCacheMagic = 0x31444D46;
    private const int LabelCacheMagic = 0x314C4D46;

    /// <summary>
    /// Builds and compiles the model, trains it on the training set, then prints predictions
    /// for a sample of test images.
    /// </summary>
    public void Run()
    {
        var model = BuildModel();
        model.summary();

        model.compile(optimizer: keras.optimizers.Adam(0.001f),
            loss: keras.losses.SparseCategoricalCrossentropy(),
            metrics: new[] { "accuracy" });

        (NDArray train_x, NDArray train_y) = LoadTrainingData();
        model.fit(train_x, train_y, batch_size: 1024, epochs: 20);

        test(model);
    }

    /// <summary>
    /// Builds the network: rescale pixels by 1/255, flatten each image, two ReLU hidden layers,
    /// and a softmax output layer with num_classes units.
    /// </summary>
    private Model BuildModel()
    {
        // Network parameters
        int n_hidden_1 = 128; // 1st layer number of neurons.
        int n_hidden_2 = 128; // 2nd layer number of neurons.
        float scale = 1.0f / 255; // pixel values are loaded as 0..255 grey levels

        var model = keras.Sequential(new List<ILayer>
        {
            // Giving the input shape builds the model up front, so summary() works before fit().
            keras.layers.Rescaling(scale, input_shape: (img_rows, img_cols)),
            // Dense layers act on the last axis only; flatten so each image yields one prediction.
            keras.layers.Flatten(),
            keras.layers.Dense(n_hidden_1, activation: keras.activations.Relu),
            keras.layers.Dense(n_hidden_2, activation: keras.activations.Relu),
            keras.layers.Dense(num_classes, activation: keras.activations.Softmax)
        });

        return model;
    }

    /// <summary>
    /// Loads the training set from the binary cache files. If the cache is missing or cannot be
    /// read, falls back to <see cref="LoadRawData"/>, which decodes the PNGs and rewrites the cache.
    /// </summary>
    /// <returns>Images of shape (N, img_rows, img_cols) and their integer labels of shape (N).</returns>
    private (NDArray, NDArray) LoadTrainingData()
    {
        Console.WriteLine("Load data");
        try
        {
            float[,,] arrx;
            int[] arry;

            using (var reader = new BinaryReader(new FileStream(train_date_path, FileMode.Open, FileAccess.Read)))
            {
                arrx = ReadImageCache(reader);
            }
            using (var reader = new BinaryReader(new FileStream(train_label_path, FileMode.Open, FileAccess.Read)))
            {
                arry = ReadLabelCache(reader);
            }

            if (arry.Length != arrx.GetLength(0))
            {
                throw new InvalidDataException("Image and label caches contain different sample counts.");
            }

            Console.WriteLine("Load data success");
            return (np.array(arrx), np.array(arry));
        }
        catch (Exception ex)
        {
            // Any cache problem (missing, truncated, old format) is recovered by re-reading the PNGs.
            Console.WriteLine($"Load data Exception:{ex.Message}");
            return LoadRawData();
        }
    }

    /// <summary>
    /// Decodes every PNG below TrainImagePath and writes the result to the binary cache files.
    /// Each sub-directory's name is parsed as the label of the images it contains. Sub-directories
    /// whose name is not an integer, and images that are not img_cols x img_rows, are skipped.
    /// </summary>
    private (NDArray, NDArray) LoadRawData()
    {
        Console.WriteLine("LoadRawData");

        // Enumerate first so the arrays are sized from the data instead of a hard-coded total.
        var dirs = new List<(int Label, FileInfo[] Files)>();
        int total_size = 0;
        foreach (var dir in new DirectoryInfo(TrainImagePath).GetDirectories())
        {
            if (!int.TryParse(dir.Name, out int dirLabel))
            {
                Console.WriteLine($"Skip directory with non-numeric name: {dir.Name}");
                continue;
            }
            var dirFiles = dir.GetFiles("*.png");
            dirs.Add((dirLabel, dirFiles));
            total_size += dirFiles.Length;
        }

        float[,,] arrx = new float[total_size, img_rows, img_cols];
        int[] arry = new int[total_size];
        int count = 0;

        foreach (var (label, files) in dirs)
        {
            foreach (var file in files)
            {
                // Dispose each bitmap: Image.FromFile keeps the file open until disposal.
                using (var bmp = (Bitmap)Image.FromFile(file.FullName))
                {
                    if (bmp.Width != img_cols || bmp.Height != img_rows)
                    {
                        continue;
                    }
                    CopyGrayscale(bmp, arrx, count);
                }
                arry[count] = label;
                count++;
            }
            Console.WriteLine($"Load image data count={count}");
        }

        if (count < total_size)
        {
            // Some images were skipped; drop the unused trailing slots so they are not trained on.
            var trimmedX = new float[count, img_rows, img_cols];
            Array.Copy(arrx, trimmedX, count * img_rows * img_cols);
            var trimmedY = new int[count];
            Array.Copy(arry, trimmedY, count);
            arrx = trimmedX;
            arry = trimmedY;
        }
        Console.WriteLine("LoadRawData finished");

        // Save data
        Console.WriteLine("Save data");
        try
        {
            using (var writer = new BinaryWriter(new FileStream(train_date_path, FileMode.Create, FileAccess.Write)))
            {
                WriteImageCache(writer, arrx);
            }
            using (var writer = new BinaryWriter(new FileStream(train_label_path, FileMode.Create, FileAccess.Write)))
            {
                WriteLabelCache(writer, arry);
            }
            Console.WriteLine("Save data finished");
        }
        catch (Exception ex) when (ex is IOException || ex is UnauthorizedAccessException)
        {
            // The cache only speeds up later runs; failing to write it must not discard the loaded data.
            Console.WriteLine($"Save data Exception:{ex.Message}");
        }

        return (np.array(arrx), np.array(arry));
    }

    /// <summary>
    /// Reads an image array written by <see cref="WriteImageCache"/>.
    /// </summary>
    /// <exception cref="InvalidDataException">Unknown header, or dimensions that differ from img_rows/img_cols.</exception>
    /// <exception cref="EndOfStreamException">The file is truncated.</exception>
    private float[,,] ReadImageCache(BinaryReader reader)
    {
        if (reader.ReadInt32() != ImageCacheMagic)
        {
            throw new InvalidDataException("Unrecognised image cache format.");
        }
        int count = reader.ReadInt32();
        int rows = reader.ReadInt32();
        int cols = reader.ReadInt32();
        if (count < 0 || rows != img_rows || cols != img_cols)
        {
            throw new InvalidDataException("Image cache dimensions do not match the model input.");
        }

        var arrx = new float[count, rows, cols];
        for (int i = 0; i < count; i++)
        {
            for (int row = 0; row < rows; row++)
            {
                for (int col = 0; col < cols; col++)
                {
                    arrx[i, row, col] = reader.ReadSingle();
                }
            }
        }
        return arrx;
    }

    /// <summary>
    /// Writes the marker, the three dimensions, then every value in row-major order
    /// (the order in which foreach enumerates a multi-dimensional array).
    /// </summary>
    private static void WriteImageCache(BinaryWriter writer, float[,,] arrx)
    {
        writer.Write(ImageCacheMagic);
        writer.Write(arrx.GetLength(0));
        writer.Write(arrx.GetLength(1));
        writer.Write(arrx.GetLength(2));
        foreach (float value in arrx)
        {
            writer.Write(value);
        }
    }

    /// <summary>
    /// Reads a label array written by <see cref="WriteLabelCache"/>.
    /// </summary>
    /// <exception cref="InvalidDataException">Unknown header or negative count.</exception>
    /// <exception cref="EndOfStreamException">The file is truncated.</exception>
    private static int[] ReadLabelCache(BinaryReader reader)
    {
        if (reader.ReadInt32() != LabelCacheMagic)
        {
            throw new InvalidDataException("Unrecognised label cache format.");
        }
        int count = reader.ReadInt32();
        if (count < 0)
        {
            throw new InvalidDataException("Label cache has a negative sample count.");
        }

        var arry = new int[count];
        for (int i = 0; i < count; i++)
        {
            arry[i] = reader.ReadInt32();
        }
        return arry;
    }

    /// <summary>
    /// Writes the marker, the label count, then each label.
    /// </summary>
    private static void WriteLabelCache(BinaryWriter writer, int[] arry)
    {
        writer.Write(LabelCacheMagic);
        writer.Write(arry.Length);
        foreach (int label in arry)
        {
            writer.Write(label);
        }
    }

    /// <summary>
    /// Copies the grey level of every pixel of <paramref name="bmp"/> into slot
    /// <paramref name="index"/> of <paramref name="target"/>. The grey level is the integer mean
    /// of the R, G and B channels (0..255). The caller must ensure the bitmap is img_cols x img_rows.
    /// </summary>
    private void CopyGrayscale(Bitmap bmp, float[,,] target, int index)
    {
        for (int row = 0; row < img_rows; row++)
        {
            for (int col = 0; col < img_cols; col++)
            {
                var pixel = bmp.GetPixel(col, row);
                target[index, row, col] = (pixel.R + pixel.G + pixel.B) / 3;
            }
        }
    }

    /// <summary>
    /// Uses the model: for each sub-directory of TestImagePath, picks 10 images at random
    /// (with replacement) and prints the predicted class of each.
    /// </summary>
    private void test(Model model)
    {
        Random rand = new Random(1); // fixed seed so every run samples the same images
        DirectoryInfo TestDir = new DirectoryInfo(TestImagePath);
        foreach (var ChildDir in TestDir.GetDirectories())
        {
            var Files = ChildDir.GetFiles("*.png");
            if (Files.Length == 0)
            {
                continue;
            }

            for (int i = 0; i < 10; i++)
            {
                int index = rand.Next(Files.Length);
                var image = Files[index];
                var x = LoadImage(image.FullName);
                var pred_y = model.Apply(x);
                var result = argmax(pred_y[0].numpy());
                Console.WriteLine($"FileName:{image.Name}\tPred:{result}");
            }
        }
    }

    /// <summary>
    /// Loads one image as a (1, img_rows, img_cols) batch of grey levels.
    /// </summary>
    /// <exception cref="InvalidDataException">The image is not img_cols x img_rows.</exception>
    private NDArray LoadImage(string filename)
    {
        float[,,] arrx = new float[1, img_rows, img_cols];
        using (var bmp = (Bitmap)Image.FromFile(filename))
        {
            if (bmp.Width != img_cols || bmp.Height != img_rows)
            {
                throw new InvalidDataException(
                    $"Image {filename} is {bmp.Width}x{bmp.Height}, expected {img_cols}x{img_rows}.");
            }
            CopyGrayscale(bmp, arrx, 0);
        }
        return np.array(arrx);
    }

    /// <summary>
    /// Returns the index of the largest value in <paramref name="array"/> (read as a flat float
    /// sequence). The first index wins on ties, and 0 is returned for an empty array.
    /// </summary>
    private int argmax(NDArray array)
    {
        float[] values = array.ToArray<float>();
        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            if (values[i] > values[best])
            {
                best = i;
            }
        }
        return best;
    }
}
// 源码 (source code) — Git: https://gitee.com/seabluescn/tf_not.git
