跟我学ModelArts丨探索ModelArts平台个性化联邦学习API

随着数字技术的发展,以及全社会对数字化的不断重视,数据的资源属性前所未有地突显出来。相应地,数据隐私和数据安全也越来越受到人们的关注,联邦学习应运而生,成为当前比较热门的一个AI算法发展方向。

什么是联邦学习:

联邦学习(Federated Learning)是一种新兴的人工智能基础技术,在 2016 年由谷歌最先提出,原本用于解决安卓手机终端用户在本地更新模型的问题,其设计目标是在保障大数据交换时的信息安全、保护终端数据和个人数据隐私、保证合法合规的前提下,在多参与方或多计算结点之间开展高效率的机器学习。

金融运用领域前景:

目前在金融领域,各个金融机构都会建设基于自己的业务场景风控模型,当运用了联邦学习即可基于各自的风控模型建立联合模型,就能更准确地识别信贷风险,金融欺诈。同时共同建立联邦学习模型,还能解决原数据样本少、数据质量低的问题。

如何实战联邦学习:

ModelArts提供了一个实现个性化联邦学习的API——pytorch_fedamp_emnist_classification,它主要是让拥有相似数据分布的客户进行更多合作的一个横向联邦学习框架,让我们来对它进行一些学习和探索。

1. 环境准备

1.1. 导入文件操作模块和输出清理模块

import os
import shutil
from IPython.display import clear_output

1.2. 下载联邦学习包并清除输出

1.3. 如果存在FedAMP文件夹,则把它完整地删除,然后重新创建FedAMP文件夹

if os.path.exists('FedAMP'):
    shutil.rmtree('FedAMP')
!mkdir FedAMP

1.4. 把下载的联邦学习包解压到该文件夹,删除压缩包,并清理输出

!unzip FedAMP.zip -d FedAMP
!rm FedAMP.zip
clear_output(wait=False)

1.5. 安装基于Pytorch的图像处理模块torchvision,并清理输出

!pip install torchvision==0.5.0
clear_output(wait=False)

1.6. 安装torch框架并清理输出

!pip install torch==1.4.0
clear_output(wait=False)

1.7. 安装联邦学习包,删除原文件,并清理输出

!pip install FedAMP/package/moxing_pytorch-1.17.3.fed-cp36-cp36m-linux_x86_64.whl
!rm -r FedAMP/package
clear_output(wait=False)

1.8. 导入torch框架、numpyrandommatplotlib.pyplot、华为moxing

import torch, random
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import moxing as mox

1.9. 导入华为moxing框架下的联邦算法、支持和服务

from moxing.framework.federated import fed_algorithm
from moxing.framework.federated import fed_backend
from moxing.framework.federated import fed_server

1.10. 导入华为moxing torch框架下联邦学习的的载入、保存、hook和服务

from moxing.pytorch.executor.federated.util import torch_load
from moxing.pytorch.executor.federated.util import torch_save
from moxing.pytorch.executor.federated.util import TorchFedHook
from moxing.pytorch.executor.federated import client

1.11. 导入FedAMPtorch.nn下的函数包

from moxing.framework.federated.fed_algorithm import FedAMP
import torch.nn.functional as F

1.12. 准备好文件路径

if mox.file.is_directory('/tmp/fed_workspace/'):
    mox.file.remove('/tmp/fed_workspace/', recursive=True)

2. 建立数据读取类和数据结构类(具体内容将在下文用到时说明)

class DataFileLoaderHorizontal():
    def __init__(self, data=None, label=None):
        if data is None:
            self.data = data
        if label is None:
            self.label = label
    def getDataToTorch(self):
        return torch.FloatTensor(self.data), torch.FloatTensor(self.label)
    def load_file_binary(self, data_filename=None, label_filename=None):
        assert data_filename is not None
        assert label_filename is not None
        self.data = np.load(data_filename, allow_pickle=True)
        self.label = np.load(label_filename, allow_pickle=True)
        self.data, self.label = self.data.astype(float), self.label.astype(float)
class m_Data():
    def __init__(self):
        self.train_samples          = None
        self.train_labels           = None
        self.test_samples           = None
        self.train_samples          = None

3. 将数据读入对应虚拟租户

3.1. 设置虚拟租户数量

num_clients = 62

3.2. 创建一个数据读取类

df1 = DataFileLoaderHorizontal()

3.3. 初始化训练集、测试集文件名和文件扩展名

rain_sample_filename = 'FedAMP/EMNIST/client_train_samples_'
train_label_filename = 'FedAMP/EMNIST/client_train_labels_'
val_sample_filename = 'FedAMP/EMNIST/client_test_samples_'
val_label_filename = 'FedAMP/EMNIST/client_test_labels_'
filename_sx = '.npy'

3.4. 让我们来探索一下训练数据集

3.4.1. 先导入一个样本集

df1.load_file_binary(data_filename=train_sample_filename + str(1) + filename_sx,
                         label_filename=train_label_filename + str(1) + filename_sx)

这里使用了“2.”DataFileLoaderHorizontal类的load_file_binary方法,该方法首先确认传入的文件名不为空,然后numpyload方法将.npy文件载入,最后用astype方法将其转为float类型

3.4.2. 先看一下自变量

df1.data

array([[[0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.],

        ...,

        [0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.]],

 

       ...,

 

       [[0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.],

        ...,

        [0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.]],

 

       [[0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.],

        ...,

        [0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.],

        [0., 0., 0., ..., 0., 0., 0.]]])

df1.data[0]

array([[0.        , 0.        , 0.        , 0.        , 0.        ,

  1. , 0.        , 0.        , 0.        , 0.        ,
  2. , 0.        , 0.        , 0.        , 0.        ,
  3. , 0.        , 0.        , 0.        , 0.        ,
  4. , 0.        , 0.        , 0.        , 0.        ,
  5. , 0.        , 0.        ],

       [0.        , 0.        , 0.        , 0.        , 0.        ,

  1. , 0.        , 0.        , 0.        , 0.        ,
  2. , 0.        , 0.        , 0.        , 0.        ,
  3. , 0.        , 0.        , 0.        , 0.        ,
  4.         , 0.        , 0.        , 0.        , 0.        ,
  5. , 0.        , 0.        ],

………………

 

       [0.        , 0.        , 0.        , 0.        , 0.        ,

  1. , 0.        , 0.        , 0.        , 0.        ,
  2. , 0.        , 0.        , 0.        , 0.        ,
  3. , 0.        , 0.        , 0.        , 0.        ,
  4. , 0.        , 0.        , 0.        , 0.        ,
  5. , 0.        , 0.        ],

       [0.        , 0.        , 0.        , 0.        , 0.        ,

  1. , 0.        , 0.        , 0.        , 0.        ,
  2. , 0.        , 0.        , 0.        , 0.        ,
  3. , 0.        , 0.        , 0.        , 0.        ,
  4.         , 0.        , 0.        , 0.        , 0.        ,
  5. , 0.        , 0.        ]])
len(df1.data)

1000

可以看到,它是一个比较稀疏的三维数组,由1000个二维数组组成

3.4.3. 再看一下标签集

df1.label

array([40.,  7.,  5.,  4.,  8.,  8.,  0.,  9.,  3., 34.,  2.,  2.,  8.,

        6.,  5.,  3.,  9.,  9.,  6.,  8.,  5.,  6.,  6.,  6.,  6.,  5.,

        5.,  8.,  5.,  6.,  7.,  5.,  8., 59.,  2.,  9.,  7.,  6.,  3.,

        4., 57.,  7.,  9., 49., 52., 25.,  4.,  2., 43.,  6.,  9.,  5.,

        3.,  5.,  7.,  7.,  3.,  0.,  6.,  7.,  5., 27.,  9., 24.,  2.,

        2.,  7.,  6.,  1.,  9., 45.,  7.,  0., 14.,  9.,  9.,  0.,  2.,

        6.,  5.,  1.,  4.,  6., 36.,  8.,  0., 34.,  0.,  0., 53.,  5.,

        0.,  2.,  7., 52., 32.,  2.,  4., 35., 49., 15.,  2., 60.,  8.,

        0.,  7., 51., 19.,  3.,  1., 24.,  9.,  2.,  2.,  4.,  8.,  8.,

        4.,  3.,  0.,  9.,  7.,  6., 26.,  7.,  4.,  7.,  7.,  2.,  8.,

        9.,  4.,  1.,  4.,  9.,  8.,  3., 16.,  0.,  5.,  3., 16.,  5.,

        3.,  1.,  7., 19., 53.,  4.,  0.,  2.,  5., 23., 19., 46.,  5.,

        2.,  7.,  4., 51., 57.,  7., 16.,  2.,  1.,  0.,  2.,  4.,  0.,

       41., 21.,  8.,  1.,  2., 39.,  3.,  1.,  6., 58., 32.,  3.,  9.,

        6.,  4.,  3., 54.,  7.,  3., 60.,  7.,  8.,  3.,  3.,  2.,  2.,

        5., 60.,  5.,  5.,  6.,  7.,  9.,  9.,  2.,  8.,  3., 43.,  5.,

        1.,  9.,  5.,  9., 13.,  6.,  7.,  6.,  6., 59.,  0.,  8.,  7.,

        7.,  2., 57.,  4.,  8.,  3.,  4.,  6.,  4.,  3.,  9.,  8.,  0.,

        6., 48.,  0.,  4.,  2.,  3.,  4.,  8., 18.,  2.,  2.,  4., 30.,

        7.,  2.,  9.,  7.,  1.,  1.,  2., 20., 36.,  9.,  5., 32.,  3.,

        3.,  3.,  3.,  3., 20., 37.,  1., 25.,  1.,  0., 57.,  2.,  2.,

        0.,  3.,  9.,  2., 18.,  2.,  3., 40., 28.,  1.,  4.,  2.,  8.,

        4.,  8.,  5.,  0., 18.,  0.,  1.,  2.,  7.,  8.,  6.,  0.,  2.,

        5., 35.,  0.,  1., 53.,  2.,  3.,  3.,  2.,  8., 32.,  3.,  5.,

        6.,  8.,  2.,  7., 40.,  8.,  5.,  6.,  8.,  4.,  9.,  1., 13.,

        6.,  3.,  3.,  5.,  3., 51., 60.,  2.,  3., 40.,  1.,  0., 47.,

       59.,  9.,  6.,  1.,  2.,  1.,  9.,  8.,  0.,  3.,  8., 53., 61.,

        8.,  5., 18.,  7.,  0.,  4.,  1.,  1., 51.,  0.,  9., 43.,  6.,

       51.,  5.,  7., 22., 24., 42.,  3., 47.,  0., 59.,  7., 42.,  7.,

       58.,  7.,  1.,  0.,  4.,  8.,  8.,  8., 20.,  1., 16.,  9.,  0.,

        3., 23.,  6.,  4., 45.,  5.,  0.,  1.,  2.,  9.,  1., 27.,  9.,

        5.,  4.,  7.,  7.,  0., 15.,  3.,  9., 36.,  9., 47.,  3., 29.,

       56., 42.,  2.,  7., 42.,  4.,  1.,  9.,  0., 34.,  3.,  5.,  0.,

       15.,  0.,  6.,  4.,  7.,  4.,  5.,  0., 15.,  9.,  8., 43.,  7.,

        7.,  6., 42.,  6.,  8.,  7., 61.,  2.,  8.,  1.,  5.,  7., 57.,

        2., 23.,  9.,  4.,  1., 59.,  3.,  1.,  9.,  9., 15.,  5., 47.,

       27.,  6.,  6.,  0.,  4.,  2.,  3.,  2., 22.,  3.,  6.,  2.,  6.,

        5.,  8.,  7.,  9.,  7.,  3., 49.,  5.,  5.,  1.,  6.,  8.,  0.,

        6.,  7., 45.,  4.,  6.,  3.,  9.,  5.,  0., 12., 18.,  8.,  4.,

        3.,  4.,  6.,  6.,  4.,  5.,  3., 29.,  7.,  7.,  5.,  9.,  7.,

        4.,  0.,  6.,  8.,  5.,  2.,  8.,  1.,  9.,  8.,  7., 25.,  1.,

        6.,  8.,  4.,  9.,  3.,  1.,  2.,  9.,  2.,  5.,  1.,  9.,  5.,

        1.,  2.,  1.,  5., 24., 45.,  7.,  0.,  4.,  8., 49.,  9.,  6.,

        4.,  2., 35.,  4.,  9.,  8.,  7.,  8.,  1.,  6.,  1.,  7.,  9.,

        1.,  8.,  1.,  1.,  3.,  0., 17., 47.,  6.,  0.,  3.,  2.,  5.,

        5., 55., 28.,  9., 56.,  7.,  8.,  2.,  2., 50.,  8.,  4.,  9.,

        4.,  3.,  1.,  1.,  0.,  5., 38.,  8.,  9.,  0.,  1.,  5.,  2.,

       25.,  5.,  0.,  4.,  7.,  9.,  7., 61.,  4.,  4.,  2.,  2.,  6.,

       41., 45., 20.,  5.,  8.,  5.,  8.,  7.,  9.,  4.,  3.,  1.,  7.,

       19.,  3.,  8.,  1.,  9.,  7., 27.,  3.,  0.,  4.,  8.,  8.,  2.,

       46.,  6.,  6.,  5.,  1.,  8.,  6.,  8.,  2.,  4.,  5., 33.,  5.,

        5.,  5.,  8.,  0.,  2., 31.,  5.,  1.,  7.,  1.,  5., 48., 41.,

        9.,  4., 61.,  9.,  9., 34., 16.,  7.,  5.,  0.,  5., 32.,  0.,

       52.,  3.,  1.,  4.,  6., 29.,  4.,  2.,  0.,  4.,  0.,  1., 48.,

        3.,  9.,  5.,  1.,  7.,  6.,  4.,  4.,  5.,  8.,  8.,  9.,  1.,

       46.,  0., 29.,  0.,  5.,  4.,  4., 48., 56.,  9.,  3.,  1.,  3.,

        1.,  5.,  7.,  9.,  8.,  8.,  6.,  6.,  0.,  8.,  0., 53.,  1.,

        6.,  1.,  4.,  4.,  8., 11.,  9.,  8.,  1., 44.,  4.,  2.,  1.,

        3.,  7.,  6.,  2., 39.,  8.,  9.,  4.,  6.,  4.,  1.,  2.,  7.,

       33.,  4., 36.,  3., 40.,  1.,  8.,  5.,  3.,  3.,  3., 28., 13.,

        9.,  1., 46.,  1.,  5., 22.,  0.,  9.,  0.,  0.,  2.,  1.,  2.,

       43.,  7.,  4.,  0.,  2., 28., 39., 48.,  4.,  0.,  5.,  3.,  6.,

        6.,  7., 19.,  6.,  4.,  0., 35., 13.,  3., 28.,  2.,  6., 23.,

        2.,  5.,  1.,  0.,  8.,  8.,  2., 10., 27.,  0., 49., 58., 23.,

        9.,  2.,  7.,  7.,  2.,  9.,  5.,  4.,  9., 22.,  5.,  8.,  6.,

        4., 58.,  6.,  5.,  4.,  9.,  1.,  7.,  0.,  3., 33.,  3.,  7.,

        9.,  6.,  3.,  1.,  1.,  6.,  2.,  1.,  2.,  7.,  3.,  7.,  8.,

        6.,  0.,  4., 34., 41.,  8.,  3.,  6.,  8.,  6.,  1.,  6.,  3.,

       56., 24.,  0.,  0.,  1., 58.,  0.,  1.,  9., 29.,  8.,  9.,  6.,

        6.,  8.,  9.,  1., 39.,  3.,  0.,  4., 25.,  8., 33.,  0.,  2.,

        3.,  7.,  5.,  0.,  7.,  7.,  6., 46.,  7.,  8.,  6.,  2.,  0.,

        8.,  7.,  5., 20., 56.,  9.,  4., 41.,  9.,  8.,  4., 13.,  5.,

        3., 61.,  4.,  5.,  1., 33.,  0.,  1.,  7.,  1.,  0.,  6.,  3.,

        6.,  2.,  6.,  4., 22.,  5.,  4., 36.,  0.,  9.,  2.,  9.,  3.,

        2.,  0.,  0.,  7.,  2., 35.,  5.,  9.,  4.,  4.,  0.,  6.,  6.,

        9.,  5.,  5., 39.,  3.,  1., 60.,  4., 52.,  6.,  4.,  0.,  1.,

        6.,  9.,  8., 52.,  3.,  1.,  7.,  3.,  3.,  9.,  7.,  8.])

可以推测,每一份训练集由1000个样本组成,自变量为二维数组

3.4.4. 将样本集和标签集转化为torch.FloatTensor类型

samples, labels = df1.getDataToTorch()

3.5. 让我们来创建一个“2.0”m_Data的实例,并将训练样本集和标签集导入m_Data

3.5.1. 先来创建一个m_Data

m_data = m_Data()

3.5.2. 初始化输入格式

input_dim = (-1, 1, 28, 28)

3.5.3. 创建m_data的训练集

m_data.train_samples = samples.reshape(input_dim)
m_data.train_labels = labels.squeeze()

3.5.4. 创建m_data的测试集

df1.load_file_binary(data_filename=val_sample_filename + str(1) + filename_sx,
                     label_filename=val_label_filename + str(1) + filename_sx)
samples, labels = df1.getDataToTorch()
m_data.val_samples = samples.reshape(input_dim)
m_data.val_labels = labels.squeeze()

在此,我们对比m_Data的数据结构,可以发现,m_Data的数据结构中似乎有个小bug,尽管它不影响使用 class m_Data(): def init(self): self.train_samples = None self.train_labels = None self.test_samples = None self.train_samples = None 这里的test_samples应该是val_samples,最后一个train_samples应该是val_labels

pytorch_fedamp_emnist_classification的第一次学习让我们先学到这里。我们已经在ModelAts上实现了pytorch_fedamp_emnist_classification的环境配置,对样本数据结构以及pytorch_fedamp_emnist_classification需要的数据结构进行了简单地探索。

下周末继续。

(完)