✨ feat(client): 新增基于Syft的CIFAR10 CNN训练客户端脚本 - 定义简单卷积神经网络模型SimpleCNN - 实现train_cnn_model函数进行模型训练 - 通过Syft客户端连接远程服务器并获取数据集 - 创建研究项目和代码请求,执行远程训练任务 - 包含异常处理和训练状态打印 ✨ feat(server): 新增基于Syft的CIFAR10数据服务器脚本 - 启动Syft数据服务器并登录默认账户 - 设置数据所有者账户信息并重新登录 - 下载并加载CIFAR10训练数据集,创建数据加载器 - 生成模拟数据并创建对应资产 - 构建数据集元数据并上传至服务器 - 创建研究者账户供客户端使用 - 服务器保持运行直到手动停止
118 lines
3.5 KiB
Python
118 lines
3.5 KiB
Python
import syft as sy
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
|
|
class SimpleCNN(nn.Module):
|
|
def __init__(self):
|
|
super(SimpleCNN, self).__init__()
|
|
self.conv1 = nn.Conv2d(3, 6, 5)
|
|
self.pool = nn.MaxPool2d(2, 2)
|
|
self.conv2 = nn.Conv2d(6, 16, 5)
|
|
self.fc1 = nn.Linear(16 * 5 * 5, 120)
|
|
self.fc2 = nn.Linear(120, 84)
|
|
self.fc3 = nn.Linear(84, 10)
|
|
|
|
def forward(self, x):
|
|
x = self.pool(torch.relu(self.conv1(x)))
|
|
x = self.pool(torch.relu(self.conv2(x)))
|
|
x = x.view(-1, 16 * 5 * 5)
|
|
x = torch.relu(self.fc1(x))
|
|
x = torch.relu(self.fc2(x))
|
|
x = self.fc3(x)
|
|
return x
|
|
|
|
def train_cnn_model(images, labels):
|
|
# 创建模型
|
|
model = SimpleCNN()
|
|
criterion = nn.CrossEntropyLoss()
|
|
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
|
|
|
|
# 训练模型
|
|
print("Starting training...")
|
|
for epoch in range(2): # 只训练2个epoch作为示例
|
|
running_loss = 0.0
|
|
for i in range(len(images)):
|
|
# 获取数据
|
|
inputs = images[i].unsqueeze(0)
|
|
target = labels[i].unsqueeze(0)
|
|
|
|
# 前向传播
|
|
optimizer.zero_grad()
|
|
outputs = model(inputs)
|
|
loss = criterion(outputs, target)
|
|
|
|
# 反向传播
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
running_loss += loss.item()
|
|
if i % 10 == 9: # 每10个batch打印一次
|
|
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 10:.3f}')
|
|
running_loss = 0.0
|
|
|
|
return model
|
|
|
|
def main():
|
|
# 直接连接到已运行的服务器
|
|
client = sy.login(
|
|
url="localhost:8093",
|
|
email="researcher@cifar10.research",
|
|
password="syftrocks"
|
|
)
|
|
|
|
try:
|
|
# 获取数据集
|
|
dataset = client.datasets["CIFAR10 Training Dataset"]
|
|
print(f"Retrieved dataset: {dataset.name}")
|
|
|
|
# 获取资产
|
|
images, labels = dataset.assets
|
|
print(f"Retrieved assets: {images.name}, {labels.name}")
|
|
|
|
# 使用模拟数据
|
|
train_images = images.mock
|
|
train_labels = labels.mock
|
|
|
|
# 创建项目
|
|
project_description = """
|
|
The purpose of this study is to train a CNN model on CIFAR10 data.
|
|
The model architecture includes two convolutional layers and three fully connected layers.
|
|
We will evaluate the model's performance on the training data.
|
|
"""
|
|
|
|
research_project = client.create_project(
|
|
name="CIFAR10 CNN Project",
|
|
description=project_description,
|
|
user_email_address="researcher@cifar10.research"
|
|
)
|
|
|
|
# 创建远程函数
|
|
remote_train_function = sy.syft_function_single_use(
|
|
images=train_images,
|
|
labels=train_labels
|
|
)(train_cnn_model)
|
|
|
|
# 创建代码请求
|
|
code_request = research_project.create_code_request(
|
|
remote_train_function,
|
|
client
|
|
)
|
|
|
|
# 等待数据所有者批准请求
|
|
print("Waiting for data owner to approve the request...")
|
|
input("Press Enter after the data owner has approved the request...")
|
|
|
|
# 执行训练
|
|
model = client.code.train_cnn_model(
|
|
images=train_images,
|
|
labels=train_labels
|
|
)
|
|
|
|
print('Training completed successfully!')
|
|
|
|
except Exception as e:
|
|
print(f"Error occurred: {str(e)}")
|
|
|
|
if __name__ == '__main__':
|
|
main() |