import syft as sy
import torch
import torch.nn as nn
import torch.optim as optim


class SimpleCNN(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three linear layers.

    The flatten size ``16 * 5 * 5`` corresponds to 3-channel 32x32 inputs
    (32 -> conv5 -> 28 -> pool2 -> 14 -> conv5 -> 10 -> pool2 -> 5). The final
    layer emits 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return class logits with shape (batch, 10) for input images ``x``."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def train_cnn_model(images, labels):
    """Train a small CNN on ``images``/``labels`` with SGD and return the model.

    This function is submitted to a PySyft server through
    ``sy.syft_function_single_use``. The server executes only this function's
    own source, without the submitting module's globals. For that reason every
    import and the network definition live inside the body instead of relying
    on the module-level ``torch``/``nn``/``optim``/``SimpleCNN``.

    Args:
        images: Indexable collection of image samples (tensors or
            array-likes). Each sample is converted to a float32 tensor and
            given a leading batch dimension. Presumably each sample is
            (3, 32, 32) to match the network; TODO confirm against the dataset.
        labels: Indexable collection of class labels aligned with ``images``.
            Integer labels are cast to int64, as ``CrossEntropyLoss`` requires.

    Returns:
        The trained network (a locally defined ``SimpleCNN`` instance).

    Raises:
        ValueError: If ``images`` and ``labels`` have different lengths.
    """
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class SimpleCNN(nn.Module):
        # Local copy of the module-level SimpleCNN, required because the
        # PySyft server cannot see module globals. Keep the two in sync.

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(torch.relu(self.conv1(x)))
            x = self.pool(torch.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    if len(images) != len(labels):
        raise ValueError(
            "images and labels must have the same length, "
            f"got {len(images)} and {len(labels)}"
        )

    log_every = 10  # report the average loss every this many samples

    # Create the model
    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Train the model
    print("Starting training...")
    for epoch in range(2):  # only 2 epochs, as a demonstration
        running_loss = 0.0
        for i in range(len(images)):
            # Fetch one sample and add a batch dimension of size 1. Conv
            # layers need floating-point input, so cast array-like/uint8
            # images to float32.
            inputs = torch.as_tensor(images[i], dtype=torch.float32).unsqueeze(0)
            target = torch.as_tensor(labels[i]).unsqueeze(0)
            if not target.is_floating_point():
                # CrossEntropyLoss expects int64 class indices.
                target = target.long()

            # Forward pass
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, target)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print(
                    f"[{epoch + 1}, {i + 1:5d}] loss: "
                    f"{running_loss / log_every:.3f}"
                )
                running_loss = 0.0

    return model


def main():
    """Submit a CIFAR10 CNN training request to a local PySyft server and run it.

    Logs in as the researcher and looks up the "CIFAR10 Training Dataset". It
    then creates a project and a single-use code request for
    ``train_cnn_model``. After the user confirms the data owner has approved,
    it executes the approved code.
    Any error inside the workflow is printed and then re-raised, so the
    script exits with a non-zero status instead of appearing to succeed.
    """
    # Connect directly to an already-running server.
    # NOTE(review): credentials are hard-coded; consider reading them from the
    # environment for anything beyond a local demo.
    client = sy.login(
        url="localhost:8093",
        email="researcher@cifar10.research",
        password="syftrocks",
    )

    try:
        # Fetch the dataset.
        dataset = client.datasets["CIFAR10 Training Dataset"]
        print(f"Retrieved dataset: {dataset.name}")

        # Fetch the assets (images, labels).
        images, labels = dataset.assets
        print(f"Retrieved assets: {images.name}, {labels.name}")

        # Use the mock data.
        # NOTE(review): PySyft input policies are usually bound to the asset
        # itself, with `.mock` used only for local dry runs. Confirm that
        # binding the mock here is intended for this PySyft version.
        train_images = images.mock
        train_labels = labels.mock

        # Create the project.
        project_description = (
            " The purpose of this study is to train a CNN model on CIFAR10 data."
            " The model architecture includes two convolutional layers and three"
            " fully connected layers. We will evaluate the model's performance"
            " on the training data. \n"
        )
        research_project = client.create_project(
            name="CIFAR10 CNN Project",
            description=project_description,
            user_email_address="researcher@cifar10.research",
        )

        # Wrap the training function as a single-use remote syft function.
        remote_train_function = sy.syft_function_single_use(
            images=train_images,
            labels=train_labels,
        )(train_cnn_model)

        # Submit the code request for the data owner to review.
        research_project.create_code_request(remote_train_function, client)

        # Wait for the data owner to approve the request.
        print("Waiting for data owner to approve the request...")
        input("Press Enter after the data owner has approved the request...")

        # Run the approved training on the server.
        # NOTE(review): the return value is presumably a server-side result
        # that must be fetched explicitly (e.g. `.get()`) to use the model
        # locally; confirm against the PySyft version in use.
        client.code.train_cnn_model(
            images=train_images,
            labels=train_labels,
        )
        print("Training completed successfully!")
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        # Top-level boundary: report, then propagate so the failure is visible.
        raise


if __name__ == "__main__":
    main()