pytorch 两个网络联合训练,pytorch 训练
本文主要介绍PyTorch实现联邦学习的基本算法FedAvg。有需要的朋友可以借鉴一下,希望能有所帮助。祝大家进步很大,早日升职加薪。
00-1010i。前言二。数据介绍特征构造。联邦学习1。总体框架2。服务器3。客户四。代码实现1。初始化2。服务器3。客户端4。实验和结果。源代码和数据
目录
在之前的博客中,使用numpy手持神经网络实现了联邦学习的基本算法FedAvg。手持神经网络的效果已经很好了,但还是属于自己做轮子。建议先用PyTorch实现。
I. 前言
联合学习有很多客户端,每个客户端都有自己的数据集,他们不愿意分享。
本文选取的数据集是2016-2019年中国北方某市十个区/县的真实用电负荷数据。采集时间间隔为1小时,即每天有24个载荷值。
我们假设这10个地区的电力部门不愿意分享自己的数据,但是他们希望得到一个由所有数据训练出来的全局模型。
除了电力负荷数据,还有一个替代数据集:风力发电数据集。参数类型指定了两个数据集:type==load 表示负载数据,而 wind 表示风力数据。
II. 数据介绍
使用某一时间前24次的负荷值和相关气象数据(如温度、湿度、压力等。)来预测当时的负载值。
对于风电数据,也是利用某一时刻前24次的风电功率值和当时的相关气象数据来预测当时的风电功率值。
每个地区应该就如何制定功能集达成一致。本文使用的各地区数据特征一致,可以直接使用。
特征构造
III. 联邦学习
原论文中提出的FedAvg的框架是:
客户端由PyTorch构建:
类ANN(nn。模块):
def __init__(self,input_dim,name,B,E,type,lr):
超级(安,自我)。__init__()
self.name=name
自我。B=B
自我。E=E
self.len=0
self.type=type
self.lr=lr
self.loss=0
self.fc1=nn。线性(input_dim,20)
self.relu=nn。ReLU()
self.sigmoid=nn。乙状结肠()
self.dropout=nn。辍学()
self.fc2=nn。线性(20,20)
self.fc3=nn。线性(20,20)
self.fc4=nn。线性(20,1)
定义转发(自身,数据):
x=self.fc1(数据)
x=self.sigmoid(x)
x=self.fc2(x)
x=self.sigmoid(x)
x=self.fc3(x)
x=self.sigmoid(x)
x=self.fc4(x)
x=self.sigmoid(x)
返回
x
2. 服务器端
服务器端执行以下步骤:
简单来说,每一轮通信时都只是选择部分客户端,这些客户端利用本地的数据进行参数更新,然后将更新后的参数传给服务器,服务器汇总客户端更新后的参数形成最新的全局参数。下一轮通信时,服务器端将最新的参数分发给被选中的客户端,进行下一轮更新。
3. 客户端
客户端没什么可说的,就是利用本地数据对神经网络模型的参数进行更新。
IV. 代码实现
1. 初始化
class FedAvg:def __init__(self, options):
self.C = options[C]
self.E = options[E]
self.B = options[B]
self.K = options[K]
self.r = options[r]
self.input_dim = options[input_dim]
self.type = options[type]
self.lr = options[lr]
self.clients = options[clients]
self.nn = ANN(input_dim=self.input_dim, name=server, B=B, E=E, type=self.type, lr=self.lr).to(device)
self.nns = []
for i in range(K):
temp = copy.deepcopy(self.nn)
temp.name = self.clients[i]
self.nns.append(temp)
参数:
- K,客户端数量,本文为10个,也就是10个地区。
- C:选择率,每一轮通信时都只是选择C * K个客户端。
- E:客户端更新本地模型的参数时,在本地数据集上训练E轮。
- B:客户端更新本地模型的参数时,本地数据集batch大小为B
- r:服务器端和客户端一共进行r轮通信。
- clients:客户端集合。
- type:指定数据类型,负荷预测or风功率预测。
- lr:学习率。
- input_dim:数据输入维度。
- nn:全局模型。
- nns: 客户端模型集合。
2. 服务器端
服务器端代码如下:
def server(self):for t in range(self.r):
print(第, t + 1, 轮通信:)
m = np.max([int(self.C * self.K), 1])
# sampling
index = random.sample(range(0, self.K), m)
# dispatch
self.dispatch(index)
# local updating
self.client_update(index)
# aggregation
self.aggregation(index)
# return global model
return self.nn
其中client_update(index):
def client_update(self, index): # update nnfor k in index:
self.nns[k] = train(self.nns[k])
aggregation(index):
def aggregation(self, index):s = 0
for j in index:
# normal
s += self.nns[j].len
params = {}
with torch.no_grad():
for k, v in self.nns[0].named_parameters():
params[k] = copy.deepcopy(v)
params[k].zero_()
for j in index:
with torch.no_grad():
for k, v in self.nns[j].named_parameters():
params[k] += v * (self.nns[j].len / s)
with torch.no_grad():
for k, v in self.nn.named_parameters():
v.copy_(params[k])
dispatch(index):
def dispatch(self, index):params = {}
with torch.no_grad():
for k, v in self.nn.named_parameters():
params[k] = copy.deepcopy(v)
for j in index:
with torch.no_grad():
for k, v in self.nns[j].named_parameters():
v.copy_(params[k])
下面对重要代码进行分析:
客户端的选择
m = np.max([int(self.C * self.K), 1])index = random.sample(range(0, self.K), m)
index中存储中m个0~10间的整数,表示被选中客户端的序号。
客户端的更新
for k in index:self.client_update(self.nns[k])
服务器端汇总客户端模型的参数
关于模型汇总方式,可以参考一下我的另一篇文章:对FedAvg中模型聚合过程的理解。
当然,这只是一种很简单的汇总方式,还有一些其他类型的汇总方式。
论文Electricity Consumer Characteristics Identification: A Federated Learning Approach中总结了三种汇总方式:
normal:原始论文中的方式,即根据样本数量来决定客户端参数在最终组合时所占比例。
LA:根据客户端模型的损失占所有客户端损失和的比重来决定最终组合时参数所占比例。
LS:根据损失与样本数量的乘积所占的比重来决定。 将更新后的参数分发给被选中的客户端
def dispatch(self, index):params = {}
with torch.no_grad():
for k, v in self.nn.named_parameters():
params[k] = copy.deepcopy(v)
for j in index:
with torch.no_grad():
for k, v in self.nns[j].named_parameters():
v.copy_(params[k])
3. 客户端
客户端只需要利用本地数据来进行更新就行了:
def client_update(self, index): # update nnfor k in index:
self.nns[k] = train(self.nns[k])
其中train():
def train(ann):ann.train()
# print(p)
if ann.type == load:
Dtr, Dte = nn_seq(ann.name, ann.B, ann.type)
else:
Dtr, Dte = nn_seq_wind(ann.named, ann.B, ann.type)
ann.len = len(Dtr)
# print(len(Dtr))
loss_function = nn.MSELoss().to(device)
loss = 0
optimizer = torch.optim.Adam(ann.parameters(), lr=ann.lr)
for epoch in range(ann.E):
cnt = 0
for (seq, label) in Dtr:
cnt += 1
seq = seq.to(device)
label = label.to(device)
y_pred = ann(seq)
loss = loss_function(y_pred, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(epoch, epoch, :, loss.item())
return ann
4. 测试
def global_test(self):model = self.nn
model.eval()
c = clients if self.type == load else clients_wind
for client in c:
model.name = client
test(model)
V. 实验及结果
本次实验的参数选择为:
if __name__ == __main__:K, C, E, B, r = 10, 0.5, 50, 50, 5
type = load
input_dim = 30 if type == load else 28
_client = clients if type == load else clients_wind
lr = 0.08
options = {K: K, C: C, E: E, B: B, r: r, type: type, clients: _client,
input_dim: input_dim, lr: lr}
fedavg = FedAvg(options)
fedavg.server()
fedavg.global_test()
各个客户端单独训练(训练50轮,batch大小为50)后在本地的测试集上的表现为:
可以看到,由于各个客户端的数据都十分充足,所以每个客户端自己训练的本地模型的预测精度已经很高了。
服务器与客户端通信5轮后,服务器上的全局模型在10个客户端测试集上的表现如下所示:
可以看到,经过联邦学习框架得到全局模型在各个客户端上表现同样很好ÿ0c;这是因为十个地区上的数据分布类似。
给出numpy和PyTorch的对比:
客户端编号 1 2 3 4 5 6 7 8 9 10
同样本地模型的效果是最好的,PyTorch搭建的网络和numpy搭建的网络效果差不多,但推荐使用PyTorch,不要造轮子。
VI. 源码及数据
我把数据和代码放在了GitHub上:源码及数据,原创不易,下载时请随手给个follow和star,感谢!
以上就是PyTorch实现联邦学习的基本算法FedAvg的详细内容,更多关于PyTorch实现FedAvg算法的资料请关注盛行IT软件开发工作室其它相关文章!
郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。