A hand-written Vision Transformer (ViT) in pure NumPy: the model is fine-tuned with Hugging Face, its parameters are saved in NumPy format, and inference runs in pure NumPy, even on a Raspberry Pi.
First, a short recap copied from a Zhihu article:
Following the ViT architecture diagram, a ViT block can be broken down into the following steps:
(1) Patch embedding: take a 224x224 input image and split it into fixed-size 16x16 patches. Each image yields 224x224/(16x16) = 196 patches, so the input sequence length is 196, and each patch has dimension 16x16x3 = 768. The linear projection layer has shape 768xN (N = 768), so after the projection the tensor is still 196x768: 196 tokens, each of dimension 768. A special [CLS] token is prepended, giving a final shape of 197x768. At this point the patch embedding has turned a vision problem into a seq2seq-style sequence problem (the shape flow is sketched in the NumPy snippet below).
(2) Positional encoding (standard learnable 1D position embeddings): ViT also adds position encodings. The position encoding can be thought of as a table with N rows, where N equals the input sequence length; each row is a vector with the same dimension as the token embeddings (768). Note that the position encoding is summed with the token embeddings, not concatenated. After adding positional information the shape is still 197x768.
(3) LN / multi-head attention / LN: LayerNorm preserves the 197x768 shape. In multi-head self-attention, the input is first projected to q, k and v. With a single head, q, k and v each have shape 197x768; with 12 heads (768/12 = 64), each head's q, k and v have shape 197x64 and there are 12 q/k/v groups. The 12 head outputs are concatenated back to 197x768 and then passed through another LayerNorm, so the shape remains 197x768.
(4) MLP: the dimension is expanded and then projected back: 197x768 is expanded to 197x3072 and then reduced back to 197x768.
After one block the shape is the same as the input, 197x768, so multiple blocks can be stacked. The output Z0 corresponding to the [CLS] token is taken as the encoder's final output, i.e. the final image representation (an alternative is to drop the [CLS] token and average the outputs of all tokens), as in equation (4) of the ViT paper; an MLP head is then attached for image classification.
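As a quick sanity check of the shapes above, here is a minimal NumPy sketch with random weights (illustration only, not the trained model) that walks one image through patch cutting, the linear projection, the [CLS] token and the position embeddings:

import numpy as np

image = np.random.rand(1, 3, 224, 224)                        # (N, C, H, W)
patch, dim = 16, 768
side = 224 // patch                                           # 14 patches per side
num_patches = side * side                                     # 196

# Cut the image into 16x16 patches and flatten each one ...
patches = image.reshape(1, 3, side, patch, side, patch).transpose(0, 2, 4, 1, 3, 5)
patches = patches.reshape(1, num_patches, 3 * patch * patch)  # (1, 196, 768)
# ... then project with a 768x768 matrix (equivalent to the strided conv)
W_proj = np.random.rand(3 * patch * patch, dim)
tokens = patches @ W_proj                                     # (1, 196, 768)

cls_token = np.random.rand(1, 1, dim)
tokens = np.concatenate([cls_token, tokens], axis=1)          # (1, 197, 768)
pos_emb = np.random.rand(1, num_patches + 1, dim)
tokens = tokens + pos_emb                                     # summed, not concatenated
print(tokens.shape)                                           # (1, 197, 768)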
The NumPy implementation of ViT below shows how each component works in detail. It differs from BERT in a few ways: besides the embedding layer, the block structure is slightly different, mainly because layer normalization is applied before the attention and feed-forward layers (pre-LN), whereas BERT applies it afterwards (post-LN). A short comparison sketch follows, and then the full inference code.
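To make the ordering difference concrete, here is a minimal, self-contained sketch (a toy sub-layer with random weights, not the implementation below) contrasting the two residual-block orderings; ViT's layernorm_before / layernorm_after parameters correspond to the pre-LN variant:

import numpy as np

def layer_norm(x, eps=1e-12):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def sublayer(x, W):
    # stand-in for either the attention block or the MLP
    return x @ W

rng = np.random.default_rng(0)
x = rng.standard_normal((197, 768)).astype(np.float32)
W = (rng.standard_normal((768, 768)) * 0.02).astype(np.float32)

# ViT (pre-LN): normalize, run the sub-layer, then add the residual
pre_ln_out = x + sublayer(layer_norm(x), W)

# BERT (post-LN): run the sub-layer, add the residual, then normalize
post_ln_out = layer_norm(x + sublayer(x, W))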
import numpy as np
import os
from PIL import Image

# Load the saved model parameters
model_data = np.load('vit_model_params.npz')
for i in model_data:
    print(i, model_data[i].shape)

patch_embedding_weight = model_data["vit.embeddings.patch_embeddings.projection.weight"]
patch_embedding_bias = model_data["vit.embeddings.patch_embeddings.projection.bias"]
position_embeddings = model_data["vit.embeddings.position_embeddings"]
cls_token_embeddings = model_data["vit.embeddings.cls_token"]


def patch_embedding(images):
    # Patch size = convolution kernel size = stride = 16
    kernel_size = 16
    return conv2d(images, patch_embedding_weight, patch_embedding_bias, stride=kernel_size)


def position_embedding():
    return position_embeddings


def model_input(images):
    # (1, 768, 14, 14) -> (1, 768, 196) -> (1, 196, 768)
    patch_embedded = np.transpose(patch_embedding(images).reshape([1, 768, -1]), (0, 2, 1))
    # Prepend the [CLS] token: (1, 197, 768)
    patch_embedded = np.concatenate([cls_token_embeddings, patch_embedded], axis=1)
    # Position embedding table of shape (1, max_position, embedding_size), added element-wise
    position_embedded = position_embedding()
    embedding_output = patch_embedded + position_embedded
    return embedding_output


def softmax(x, axis=None):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    sum_ex = np.sum(e_x, axis=axis, keepdims=True).astype(np.float32)
    return e_x / sum_ex


def conv2d(images, weight, bias, stride=1, padding=0):
    # Naive convolution; assumes batch size 1, which is all the inference loop needs
    N, C, H, W = images.shape
    F, _, HH, WW = weight.shape
    # Output spatial size
    H_out = (H - HH + 2 * padding) // stride + 1
    W_out = (W - WW + 2 * padding) // stride + 1
    # Initialize the output
    out = np.zeros((N, F, H_out, W_out))
    # Slide the kernel over the image
    for i in range(H_out):
        for j in range(W_out):
            # Extract the current window
            window = images[:, :, i * stride:i * stride + HH, j * stride:j * stride + WW]
            # Convolve the window with every filter
            out[:, :, i, j] = np.sum(window * weight, axis=(1, 2, 3)) + bias
    return out


def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, np.full_like(scores, -np.inf))
    attention_weights = softmax(scores, axis=-1)
    output = np.matmul(attention_weights, V)
    return output, attention_weights


def multihead_attention(input, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O):
    q = np.matmul(input, W_Q.T) + B_Q
    k = np.matmul(input, W_K.T) + B_K
    v = np.matmul(input, W_V.T) + B_V
    # Split q, k, v into num_heads heads along the feature dimension
    q = np.split(q, num_heads, axis=-1)
    k = np.split(k, num_heads, axis=-1)
    v = np.split(v, num_heads, axis=-1)
    outputs = []
    for q_, k_, v_ in zip(q, k, v):
        output, attention_weights = scaled_dot_product_attention(q_, k_, v_)
        outputs.append(output)
    # Concatenate the heads and apply the output projection
    outputs = np.concatenate(outputs, axis=-1)
    outputs = np.matmul(outputs, W_O.T) + B_O
    return outputs


def layer_normalization(x, weight, bias, eps=1e-12):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    std = np.sqrt(variance + eps)
    normalized_x = (x - mean) / std
    output = weight * normalized_x + bias
    return output


def feed_forward_layer(inputs, weight, bias, activation='relu'):
    linear_output = np.matmul(inputs, weight) + bias
    if activation == 'relu':
        activated_output = np.maximum(0, linear_output)
    elif activation == 'gelu':
        # tanh approximation of GELU
        activated_output = 0.5 * linear_output * (1 + np.tanh(np.sqrt(2 / np.pi) * (linear_output + 0.044715 * np.power(linear_output, 3))))
    elif activation == "tanh":
        activated_output = np.tanh(linear_output)
    else:
        activated_output = linear_output  # no activation
    return activated_output


def residual_connection(inputs, residual):
    # Residual (skip) connection
    residual_output = inputs + residual
    return residual_output


def vit(input, num_heads=12):
    for i in range(12):
        # Weights of encoder layer i
        W_Q = model_data['vit.encoder.layer.{}.attention.attention.query.weight'.format(i)]
        B_Q = model_data['vit.encoder.layer.{}.attention.attention.query.bias'.format(i)]
        W_K = model_data['vit.encoder.layer.{}.attention.attention.key.weight'.format(i)]
        B_K = model_data['vit.encoder.layer.{}.attention.attention.key.bias'.format(i)]
        W_V = model_data['vit.encoder.layer.{}.attention.attention.value.weight'.format(i)]
        B_V = model_data['vit.encoder.layer.{}.attention.attention.value.bias'.format(i)]
        W_O = model_data['vit.encoder.layer.{}.attention.output.dense.weight'.format(i)]
        B_O = model_data['vit.encoder.layer.{}.attention.output.dense.bias'.format(i)]
        intermediate_weight = model_data['vit.encoder.layer.{}.intermediate.dense.weight'.format(i)]
        intermediate_bias = model_data['vit.encoder.layer.{}.intermediate.dense.bias'.format(i)]
        dense_weight = model_data['vit.encoder.layer.{}.output.dense.weight'.format(i)]
        dense_bias = model_data['vit.encoder.layer.{}.output.dense.bias'.format(i)]
        LayerNorm_before_weight = model_data['vit.encoder.layer.{}.layernorm_before.weight'.format(i)]
        LayerNorm_before_bias = model_data['vit.encoder.layer.{}.layernorm_before.bias'.format(i)]
        LayerNorm_after_weight = model_data['vit.encoder.layer.{}.layernorm_after.weight'.format(i)]
        LayerNorm_after_bias = model_data['vit.encoder.layer.{}.layernorm_after.bias'.format(i)]

        # Pre-LN -> multi-head self-attention -> residual
        output = layer_normalization(input, LayerNorm_before_weight, LayerNorm_before_bias)
        output = multihead_attention(output, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O)
        output1 = residual_connection(input, output)  # matches the Hugging Face model output here
        # Pre-LN -> MLP (768 -> 3072 -> 768) -> residual
        output = layer_normalization(output1, LayerNorm_after_weight, LayerNorm_after_bias)  # matches
        output = feed_forward_layer(output, intermediate_weight.T, intermediate_bias, activation='gelu')
        output = feed_forward_layer(output, dense_weight.T, dense_bias, activation='')
        output2 = residual_connection(output1, output)
        input = output2

    # Final LayerNorm on the [CLS] token, then the classification head
    final_layernorm_weight = model_data['vit.layernorm.weight']
    final_layernorm_bias = model_data['vit.layernorm.bias']
    output = layer_normalization(output2[:, 0], final_layernorm_weight, final_layernorm_bias)  # matches
    classifier_weight = model_data['classifier.weight']
    classifier_bias = model_data['classifier.bias']
    output = feed_forward_layer(output, classifier_weight.T, classifier_bias, activation="")  # matches
    output = np.argmax(output, axis=-1)
    return output


folder_path = './cifar10'  # replace with the folder that holds the test images


def infer_images_in_folder(folder_path):
    for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)
        if os.path.isfile(file_path) and file_name.endswith(('.jpg', '.jpeg', '.png')):
            image = Image.open(file_path)
            image = image.resize((224, 224))
            # File names look like "<index>_<label>.jpg"
            label = file_name.split(".")[0].split("_")[1]
            image = np.array(image) / 255.0
            image = np.transpose(image, (2, 0, 1))
            image = np.expand_dims(image, axis=0)
            print("file_path:", file_path, "img size:", image.shape, "label:", label)
            input = model_input(image)
            predicted_class = vit(input)
            print('Predicted class:', predicted_class)


if __name__ == "__main__":
    infer_images_in_folder(folder_path)
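One optional optimization, not part of the original code above: the nested loops in conv2d are simple but slow on a Raspberry Pi. Because the patch-embedding convolution uses a 16x16 kernel with stride 16 and no padding, it can equivalently be computed as a reshape plus a single matrix multiplication. A sketch (the name patch_embedding_fast is mine; it reuses the patch_embedding_weight and patch_embedding_bias arrays loaded above):

def patch_embedding_fast(images, weight=None, bias=None, patch=16):
    # Equivalent to a 16x16 convolution with stride 16: cut the image into
    # non-overlapping patches and project each flattened patch with one matmul.
    if weight is None:
        weight = patch_embedding_weight           # (768, 3, 16, 16)
    if bias is None:
        bias = patch_embedding_bias               # (768,)
    N, C, H, W = images.shape
    h, w = H // patch, W // patch
    patches = images.reshape(N, C, h, patch, w, patch).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(N, h * w, C * patch * patch)        # (N, 196, 768)
    out = patches @ weight.reshape(weight.shape[0], -1).T + bias  # (N, 196, 768)
    # Return the same layout as conv2d: (N, 768, 14, 14)
    return out.transpose(0, 2, 1).reshape(N, weight.shape[0], h, w)

Swapping the conv2d call inside patch_embedding for patch_embedding_fast(images) should produce the same embeddings up to floating-point rounding.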
Results:
file_path: ./cifar10/8619_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/6042_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/6801_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7946_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/6925_2.jpg img size: (1, 3, 224, 224) label: 2
Predicted class: [2]
file_path: ./cifar10/6007_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7903_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/7064_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/2713_8.jpg img size: (1, 3, 224, 224) label: 8
Predicted class: [8]
file_path: ./cifar10/8575_9.jpg img size: (1, 3, 224, 224) label: 9
Predicted class: [9]
file_path: ./cifar10/1985_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/5312_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/593_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/8093_7.jpg img size: (1, 3, 224, 224) label: 7
Predicted class: [7]
file_path: ./cifar10/6862_5.jpg img size: (1, 3, 224, 224) label: 5
Model parameters (name and shape):
vit.embeddings.cls_token (1, 1, 768)
vit.embeddings.position_embeddings (1, 197, 768)
vit.embeddings.patch_embeddings.projection.weight (768, 3, 16, 16)
vit.embeddings.patch_embeddings.projection.bias (768,)
vit.encoder.layer.0.attention.attention.query.weight (768, 768)
vit.encoder.layer.0.attention.attention.query.bias (768,)
vit.encoder.layer.0.attention.attention.key.weight (768, 768)
vit.encoder.layer.0.attention.attention.key.bias (768,)
vit.encoder.layer.0.attention.attention.value.weight (768, 768)
vit.encoder.layer.0.attention.attention.value.bias (768,)
vit.encoder.layer.0.attention.output.dense.weight (768, 768)
vit.encoder.layer.0.attention.output.dense.bias (768,)
vit.encoder.layer.0.intermediate.dense.weight (3072, 768)
vit.encoder.layer.0.intermediate.dense.bias (3072,)
vit.encoder.layer.0.output.dense.weight (768, 3072)
vit.encoder.layer.0.output.dense.bias (768,)
vit.encoder.layer.0.layernorm_before.weight (768,)
vit.encoder.layer.0.layernorm_before.bias (768,)
vit.encoder.layer.0.layernorm_after.weight (768,)
vit.encoder.layer.0.layernorm_after.bias (768,)
vit.encoder.layer.1.attention.attention.query.weight (768, 768)
vit.encoder.layer.1.attention.attention.query.bias (768,)
vit.encoder.layer.1.attention.attention.key.weight (768, 768)
vit.encoder.layer.1.attention.attention.key.bias (768,)
vit.encoder.layer.1.attention.attention.value.weight (768, 768)
vit.encoder.layer.1.attention.attention.value.bias (768,)
vit.encoder.layer.1.attention.output.dense.weight (768, 768)
vit.encoder.layer.1.attention.output.dense.bias (768,)
vit.encoder.layer.1.intermediate.dense.weight (3072, 768)
vit.encoder.layer.1.intermediate.dense.bias (3072,)
vit.encoder.layer.1.output.dense.weight (768, 3072)
vit.encoder.layer.1.output.dense.bias (768,)
vit.encoder.layer.1.layernorm_before.weight (768,)
vit.encoder.layer.1.layernorm_before.bias (768,)
vit.encoder.layer.1.layernorm_after.weight (768,)
vit.encoder.layer.1.layernorm_after.bias (768,)
vit.encoder.layer.2.attention.attention.query.weight (768, 768)
vit.encoder.layer.2.attention.attention.query.bias (768,)
vit.encoder.layer.2.attention.attention.key.weight (768, 768)
vit.encoder.layer.2.attention.attention.key.bias (768,)
vit.encoder.layer.2.attention.attention.value.weight (768, 768)
vit.encoder.layer.2.attention.attention.value.bias (768,)
vit.encoder.layer.2.attention.output.dense.weight (768, 768)
vit.encoder.layer.2.attention.output.dense.bias (768,)
vit.encoder.layer.2.intermediate.dense.weight (3072, 768)
vit.encoder.layer.2.intermediate.dense.bias (3072,)
vit.encoder.layer.2.output.dense.weight (768, 3072)
vit.encoder.layer.2.output.dense.bias (768,)
vit.encoder.layer.2.layernorm_before.weight (768,)
vit.encoder.layer.2.layernorm_before.bias (768,)
vit.encoder.layer.2.layernorm_after.weight (768,)
vit.encoder.layer.2.layernorm_after.bias (768,)
vit.encoder.layer.3.attention.attention.query.weight (768, 768)
vit.encoder.layer.3.attention.attention.query.bias (768,)
vit.encoder.layer.3.attention.attention.key.weight (768, 768)
vit.encoder.layer.3.attention.attention.key.bias (768,)
vit.encoder.layer.3.attention.attention.value.weight (768, 768)
vit.encoder.layer.3.attention.attention.value.bias (768,)
vit.encoder.layer.3.attention.output.dense.weight (768, 768)
vit.encoder.layer.3.attention.output.dense.bias (768,)
vit.encoder.layer.3.intermediate.dense.weight (3072, 768)
vit.encoder.layer.3.intermediate.dense.bias (3072,)
vit.encoder.layer.3.output.dense.weight (768, 3072)
vit.encoder.layer.3.output.dense.bias (768,)
vit.encoder.layer.3.layernorm_before.weight (768,)
vit.encoder.layer.3.layernorm_before.bias (768,)
vit.encoder.layer.3.layernorm_after.weight (768,)
vit.encoder.layer.3.layernorm_after.bias (768,)
vit.encoder.layer.4.attention.attention.query.weight (768, 768)
vit.encoder.layer.4.attention.attention.query.bias (768,)
vit.encoder.layer.4.attention.attention.key.weight (768, 768)
vit.encoder.layer.4.attention.attention.key.bias (768,)
vit.encoder.layer.4.attention.attention.value.weight (768, 768)
vit.encoder.layer.4.attention.attention.value.bias (768,)
vit.encoder.layer.4.attention.output.dense.weight (768, 768)
vit.encoder.layer.4.attention.output.dense.bias (768,)
vit.encoder.layer.4.intermediate.dense.weight (3072, 768)
vit.encoder.layer.4.intermediate.dense.bias (3072,)
vit.encoder.layer.4.output.dense.weight (768, 3072)
vit.encoder.layer.4.output.dense.bias (768,)
vit.encoder.layer.4.layernorm_before.weight (768,)
vit.encoder.layer.4.layernorm_before.bias (768,)
vit.encoder.layer.4.layernorm_after.weight (768,)
vit.encoder.layer.4.layernorm_after.bias (768,)
vit.encoder.layer.5.attention.attention.query.weight (768, 768)
vit.encoder.layer.5.attention.attention.query.bias (768,)
vit.encoder.layer.5.attention.attention.key.weight (768, 768)
vit.encoder.layer.5.attention.attention.key.bias (768,)
vit.encoder.layer.5.attention.attention.value.weight (768, 768)
vit.encoder.layer.5.attention.attention.value.bias (768,)
vit.encoder.layer.5.attention.output.dense.weight (768, 768)
vit.encoder.layer.5.attention.output.dense.bias (768,)
vit.encoder.layer.5.intermediate.dense.weight (3072, 768)
vit.encoder.layer.5.intermediate.dense.bias (3072,)
vit.encoder.layer.5.output.dense.weight (768, 3072)
vit.encoder.layer.5.output.dense.bias (768,)
vit.encoder.layer.5.layernorm_before.weight (768,)
vit.encoder.layer.5.layernorm_before.bias (768,)
vit.encoder.layer.5.layernorm_after.weight (768,)
vit.encoder.layer.5.layernorm_after.bias (768,)
vit.encoder.layer.6.attention.attention.query.weight (768, 768)
vit.encoder.layer.6.attention.attention.query.bias (768,)
vit.encoder.layer.6.attention.attention.key.weight (768, 768)
vit.encoder.layer.6.attention.attention.key.bias (768,)
vit.encoder.layer.6.attention.attention.value.weight (768, 768)
vit.encoder.layer.6.attention.attention.value.bias (768,)
vit.encoder.layer.6.attention.output.dense.weight (768, 768)
vit.encoder.layer.6.attention.output.dense.bias (768,)
vit.encoder.layer.6.intermediate.dense.weight (3072, 768)
vit.encoder.layer.6.intermediate.dense.bias (3072,)
vit.encoder.layer.6.output.dense.weight (768, 3072)
vit.encoder.layer.6.output.dense.bias (768,)
vit.encoder.layer.6.layernorm_before.weight (768,)
vit.encoder.layer.6.layernorm_before.bias (768,)
vit.encoder.layer.6.layernorm_after.weight (768,)
vit.encoder.layer.6.layernorm_after.bias (768,)
vit.encoder.layer.7.attention.attention.query.weight (768, 768)
vit.encoder.layer.7.attention.attention.query.bias (768,)
vit.encoder.layer.7.attention.attention.key.weight (768, 768)
vit.encoder.layer.7.attention.attention.key.bias (768,)
vit.encoder.layer.7.attention.attention.value.weight (768, 768)
vit.encoder.layer.7.attention.attention.value.bias (768,)
vit.encoder.layer.7.attention.output.dense.weight (768, 768)
vit.encoder.layer.7.attention.output.dense.bias (768,)
vit.encoder.layer.7.intermediate.dense.weight (3072, 768)
vit.encoder.layer.7.intermediate.dense.bias (3072,)
vit.encoder.layer.7.output.dense.weight (768, 3072)
vit.encoder.layer.7.output.dense.bias (768,)
vit.encoder.layer.7.layernorm_before.weight (768,)
vit.encoder.layer.7.layernorm_before.bias (768,)
vit.encoder.layer.7.layernorm_after.weight (768,)
vit.encoder.layer.7.layernorm_after.bias (768,)
vit.encoder.layer.8.attention.attention.query.weight (768, 768)
vit.encoder.layer.8.attention.attention.query.bias (768,)
vit.encoder.layer.8.attention.attention.key.weight (768, 768)
vit.encoder.layer.8.attention.attention.key.bias (768,)
vit.encoder.layer.8.attention.attention.value.weight (768, 768)
vit.encoder.layer.8.attention.attention.value.bias (768,)
vit.encoder.layer.8.attention.output.dense.weight (768, 768)
vit.encoder.layer.8.attention.output.dense.bias (768,)
vit.encoder.layer.8.intermediate.dense.weight (3072, 768)
vit.encoder.layer.8.intermediate.dense.bias (3072,)
vit.encoder.layer.8.output.dense.weight (768, 3072)
vit.encoder.layer.8.output.dense.bias (768,)
vit.encoder.layer.8.layernorm_before.weight (768,)
vit.encoder.layer.8.layernorm_before.bias (768,)
vit.encoder.layer.8.layernorm_after.weight (768,)
vit.encoder.layer.8.layernorm_after.bias (768,)
vit.encoder.layer.9.attention.attention.query.weight (768, 768)
vit.encoder.layer.9.attention.attention.query.bias (768,)
vit.encoder.layer.9.attention.attention.key.weight (768, 768)
vit.encoder.layer.9.attention.attention.key.bias (768,)
vit.encoder.layer.9.attention.attention.value.weight (768, 768)
vit.encoder.layer.9.attention.attention.value.bias (768,)
vit.encoder.layer.9.attention.output.dense.weight (768, 768)
vit.encoder.layer.9.attention.output.dense.bias (768,)
vit.encoder.layer.9.intermediate.dense.weight (3072, 768)
vit.encoder.layer.9.intermediate.dense.bias (3072,)
vit.encoder.layer.9.output.dense.weight (768, 3072)
vit.encoder.layer.9.output.dense.bias (768,)
vit.encoder.layer.9.layernorm_before.weight (768,)
vit.encoder.layer.9.layernorm_before.bias (768,)
vit.encoder.layer.9.layernorm_after.weight (768,)
vit.encoder.layer.9.layernorm_after.bias (768,)
vit.encoder.layer.10.attention.attention.query.weight (768, 768)
vit.encoder.layer.10.attention.attention.query.bias (768,)
vit.encoder.layer.10.attention.attention.key.weight (768, 768)
vit.encoder.layer.10.attention.attention.key.bias (768,)
vit.encoder.layer.10.attention.attention.value.weight (768, 768)
vit.encoder.layer.10.attention.attention.value.bias (768,)
vit.encoder.layer.10.attention.output.dense.weight (768, 768)
vit.encoder.layer.10.attention.output.dense.bias (768,)
vit.encoder.layer.10.intermediate.dense.weight (3072, 768)
vit.encoder.layer.10.intermediate.dense.bias (3072,)
vit.encoder.layer.10.output.dense.weight (768, 3072)
vit.encoder.layer.10.output.dense.bias (768,)
vit.encoder.layer.10.layernorm_before.weight (768,)
vit.encoder.layer.10.layernorm_before.bias (768,)
vit.encoder.layer.10.layernorm_after.weight (768,)
vit.encoder.layer.10.layernorm_after.bias (768,)
vit.encoder.layer.11.attention.attention.query.weight (768, 768)
vit.encoder.layer.11.attention.attention.query.bias (768,)
vit.encoder.layer.11.attention.attention.key.weight (768, 768)
vit.encoder.layer.11.attention.attention.key.bias (768,)
vit.encoder.layer.11.attention.attention.value.weight (768, 768)
vit.encoder.layer.11.attention.attention.value.bias (768,)
vit.encoder.layer.11.attention.output.dense.weight (768, 768)
vit.encoder.layer.11.attention.output.dense.bias (768,)
vit.encoder.layer.11.intermediate.dense.weight (3072, 768)
vit.encoder.layer.11.intermediate.dense.bias (3072,)
vit.encoder.layer.11.output.dense.weight (768, 3072)
vit.encoder.layer.11.output.dense.bias (768,)
vit.encoder.layer.11.layernorm_before.weight (768,)
vit.encoder.layer.11.layernorm_before.bias (768,)
vit.encoder.layer.11.layernorm_after.weight (768,)
vit.encoder.layer.11.layernorm_after.bias (768,)
vit.layernorm.weight (768,)
vit.layernorm.bias (768,)
classifier.weight (10, 768)
classifier.bias (10,)
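One thing worth noticing in the shapes above: the dense weights follow PyTorch's nn.Linear convention of (out_features, in_features), e.g. intermediate.dense.weight is (3072, 768). That is why the NumPy code multiplies by the transposed weight (W_Q.T, intermediate_weight.T and so on). A minimal shape check, assuming the same vit_model_params.npz file:

import numpy as np

model_data = np.load('vit_model_params.npz')
W = model_data['vit.encoder.layer.0.intermediate.dense.weight']   # (3072, 768)
x = np.zeros((1, 197, 768), dtype=np.float32)
y = x @ W.T            # mirrors PyTorch's y = x @ W^T + b for nn.Linear
print(W.shape, y.shape)                                           # (3072, 768) (1, 197, 3072)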
Hugging Face training code: fine-tune on CIFAR-10 and save the model parameters in NumPy format so that the NumPy implementation above can load them:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm
import numpy as np

# Set the random seed
torch.manual_seed(42)

# Hyperparameters
batch_size = 64
num_epochs = 1
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load the CIFAR-10 dataset
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform)

# Create the data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Load the pretrained ViT model
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)

# Replace the classification head with a 10-class head
num_classes = 10
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device)

# Optionally freeze everything except the head:
# parameters = list(vit_model.parameters())
# for x in parameters[:-1]:
#     x.requires_grad = False

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

# Fine-tune the ViT model
for epoch in range(num_epochs):
    print("epoch:", epoch)
    vit_model.train()
    train_loss = 0.0
    train_correct = 0
    bar = tqdm(train_loader, total=len(train_loader))
    for images, labels in bar:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = vit_model(images)
        loss = criterion(outputs.logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.logits, 1)
        train_correct += (predicted == labels).sum().item()

    # Accuracy on the training set
    train_accuracy = 100.0 * train_correct / len(train_dataset)

    # Evaluate on the test set
    vit_model.eval()
    test_loss = 0.0
    test_correct = 0
    with torch.no_grad():
        bar = tqdm(test_loader, total=len(test_loader))
        for images, labels in bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = vit_model(images)
            loss = criterion(outputs.logits, labels)

            test_loss += loss.item()
            _, predicted = torch.max(outputs.logits, 1)
            test_correct += (predicted == labels).sum().item()

    # Accuracy on the test set
    test_accuracy = 100.0 * test_correct / len(test_dataset)

    # Report training loss, training accuracy and test accuracy for each epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')

# Save the PyTorch checkpoint
torch.save(vit_model.state_dict(), 'vit_model_parameters.pth')

# Print the model's weight shapes
for name, param in vit_model.named_parameters():
    print(name, param.data.shape)

# Save the model parameters in NumPy format
model_params = {name: param.data.cpu().numpy() for name, param in vit_model.named_parameters()}
np.savez('vit_model_params.npz', **model_params)
Epoch [1/1], Train Loss: 97.7498, Train Accuracy: 96.21%, Test Accuracy: 96.86%
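The "matches" comments in the NumPy code refer to spot checks against the Hugging Face model. A minimal sketch of such a cross-check, assuming the vit_model_parameters.pth checkpoint saved above, one of the CIFAR-10 test images from the results (the path is only an example), and the NumPy vit / model_input functions available in the same session:

import numpy as np
import torch
from PIL import Image
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.config.hidden_size, 10)
model.load_state_dict(torch.load('vit_model_parameters.pth', map_location='cpu'))
model.eval()

image = Image.open('./cifar10/8619_5.jpg').resize((224, 224))
x = np.transpose(np.array(image) / 255.0, (2, 0, 1))[None].astype(np.float32)

with torch.no_grad():
    torch_pred = model(torch.from_numpy(x)).logits.argmax(-1).numpy()

numpy_pred = vit(model_input(x))   # the pure NumPy implementation above
print(torch_pred, numpy_pred)      # the two predictions should agree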