Browse Source

Add llm and fl beta code

naibo 10 months ago
parent
commit
5180f47b70
2 changed files with 144 additions and 0 deletions
  1. 108 0
      ExecuteStage/fl_beta.py
  2. 36 0
      ExecuteStage/llm_beta.py

+ 108 - 0
ExecuteStage/fl_beta.py

@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import models, transforms
+from torch.utils.data import DataLoader, Dataset
+import numpy as np
+from PIL import Image
+import os
+
+# 定义 ResNet 模型(以 ResNet18 为例)
+class ResNetModel(nn.Module):
+    def __init__(self, num_classes):
+        super(ResNetModel, self).__init__()
+        self.resnet = models.resnet18(pretrained=True)
+        # 修改最后的全连接层以适应特定的分类任务
+        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
+
+    def forward(self, x):
+        return self.resnet(x)
+
+# 自定义数据集类
+class WebpageDataset(Dataset):
+    def __init__(self, image_dir, transform=None):
+        self.image_dir = image_dir
+        self.transform = transform
+        self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.png')]
+
+    def __len__(self):
+        return len(self.image_files)
+
+    def __getitem__(self, idx):
+        img_name = os.path.join(self.image_dir, self.image_files[idx])
+        image = Image.open(img_name).convert('RGB')
+        label = self.get_label_from_filename(self.image_files[idx])
+        if self.transform:
+            image = self.transform(image)
+        return image, label
+
+    def get_label_from_filename(self, filename):
+        # 假设文件名格式为 'class_label.png'
+        return int(filename.split('_')[0])
+
+# 图像预处理
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
+# 定义客户端训练函数
+def train_local_model(model, dataloader, criterion, optimizer, epochs=5):
+    model.train()
+    for epoch in range(epochs):
+        for images, labels in dataloader:
+            outputs = model(images)
+            loss = criterion(outputs, labels)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+    return model.state_dict()
+
+# 联邦平均算法
+def federated_average(models_state_dicts):
+    avg_state_dict = models_state_dicts[0]
+    for key in avg_state_dict.keys():
+        for i in range(1, len(models_state_dicts)):
+            avg_state_dict[key] += models_state_dicts[i][key]
+        avg_state_dict[key] = torch.div(avg_state_dict[key], len(models_state_dicts))
+    return avg_state_dict
+
+# 模拟多个客户端的数据
+client_data_dirs = ['client1_data', 'client2_data', 'client3_data']  # 每个客户端的数据目录
+num_classes = 10  # 根据实际情况设置
+
+# 初始化全局模型
+global_model = ResNetModel(num_classes=num_classes)
+
+# 定义损失函数
+criterion = nn.CrossEntropyLoss()
+
+# 联邦学习过程
+num_rounds = 10
+for round in range(num_rounds):
+    local_models = []
+    for client_dir in client_data_dirs:
+        # 加载客户端数据
+        dataset = WebpageDataset(image_dir=client_dir, transform=transform)
+        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+
+        # 初始化客户端模型
+        local_model = ResNetModel(num_classes=num_classes)
+        local_model.load_state_dict(global_model.state_dict())
+
+        # 定义优化器
+        optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9)
+
+        # 训练本地模型
+        local_state_dict = train_local_model(local_model, dataloader, criterion, optimizer)
+        local_models.append(local_state_dict)
+
+    # 聚合模型参数
+    global_state_dict = federated_average(local_models)
+    global_model.load_state_dict(global_state_dict)
+
+    print(f'Round {round+1}/{num_rounds} completed.')
+
+# 保存全局模型
+torch.save(global_model.state_dict(), 'federated_resnet_model.pth')

+ 36 - 0
ExecuteStage/llm_beta.py

@@ -0,0 +1,36 @@
+from transformers import AutoProcessor, AutoModelForVision2Seq
+from PIL import Image
+import torch
+
+# 加载 Llama 3.2 视觉模型和处理器
+model_name = "meta-llama/Llama-3.2-11B-Vision"  # 请根据实际模型路径替换
+processor = AutoProcessor.from_pretrained(model_name)
+model = AutoModelForVision2Seq.from_pretrained(model_name)
+
+# 处理网页截图并提取结构
+def predict_structure_from_image(image_path):
+    # 加载图像
+    image = Image.open(image_path).convert("RGB")
+
+    # 预处理图像
+    inputs = processor(images=image, return_tensors="pt")
+
+    # 生成描述(结构描述)
+    outputs = model.generate(
+        inputs["pixel_values"],
+        max_length=512,
+        num_beams=5,
+        early_stopping=True
+    )
+    description = processor.decode(outputs[0], skip_special_tokens=True)
+    return description
+
+# 示例使用
+if __name__ == "__main__":
+    # 提供网页截图的路径
+    image_path = "webpage_screenshot.png"  # 请替换为实际的图像文件路径
+
+    # 预测结构
+    predicted_structure = predict_structure_from_image(image_path)
+
+    print("预测的结构:", predicted_structure)