Python Toolbox

A collection of commonly used utility snippets.

Recursively walk all file paths under a folder

Handy for batch-processing a dataset. This only collects the paths of all files into a list, which can then be written out to a list.txt file.

import os


def get_file_path(path, txt: list):
    # Recurse into sub-directories, appending every file path to txt
    for dp in os.listdir(path):
        if os.path.isdir(os.path.join(path, dp)):
            get_file_path(os.path.join(path, dp), txt)
        else:
            txt.append(os.path.join(path, dp))
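
A usage sketch matching the description above (the dataset folder here is a placeholder):

file_paths = []
get_file_path('path/to/dataset', file_paths)
with open('list.txt', 'w') as f:
    f.write('\n'.join(file_paths))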

Reading back data saved with np.savez

import numpy as np

val_set_all = dict(np.load('savemodel/fine_tune_test_set.npz', allow_pickle=True))
for name in val_set_all.keys():
    # [()] unwraps 0-d object arrays back into the original Python objects
    val_set_all[name] = val_set_all[name][()]
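
For context, a minimal round-trip sketch (file name and contents are hypothetical) showing why the [()] unwrap is needed when Python objects such as dicts were saved:

import numpy as np

val_set = {'x': np.zeros((4, 2)), 'meta': {'snr': 10}}  # hypothetical contents
np.savez('savemodel/fine_tune_test_set.npz', **val_set)
loaded = dict(np.load('savemodel/fine_tune_test_set.npz', allow_pickle=True))
print(loaded['meta'][()])  # the dict comes back as a 0-d object array; [()] recovers it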

JS divergence

import torch.nn as nn
import torch.nn.functional as F


def loss_js(p_output, q_output, get_softmax=True):
    KLDivLoss = nn.KLDivLoss(reduction='batchmean')
    if get_softmax:
        p_output = F.softmax(p_output, dim=1)
        q_output = F.softmax(q_output, dim=1)
    # nn.KLDivLoss(input, target) expects input in log space and computes KL(target || input),
    # so JS(P, Q) = (KL(P || M) + KL(Q || M)) / 2 with the mixture M = (P + Q) / 2
    log_mean_output = ((p_output + q_output) / 2).log()
    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output)) / 2
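
A quick sanity check: identical inputs should give (near-)zero divergence.

import torch

logits = torch.randn(8, 5)
print(loss_js(logits, logits))             # ~0 for identical distributions
print(loss_js(logits, torch.randn(8, 5)))  # > 0 otherwise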

Setting random seeds in PyTorch

import os
import random

import numpy as np
import torch


def seed_torch(seed=1029):
    random.seed(seed)
    # Disable hash randomization for reproducibility (note: PYTHONHASHSEED only
    # fully takes effect if it is set before the interpreter starts)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
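
A quick check that seeding makes runs reproducible:

seed_torch(1029)
a = torch.rand(3)
seed_torch(1029)
b = torch.rand(3)
print(torch.equal(a, b))  # True: the same seed reproduces the same draws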

Gradient clipping in PyTorch

loss.backward()
# Clip after backward() and before optimizer.step()
nn.utils.clip_grad_norm_(net1.parameters(), max_norm=20, norm_type=2)
nn.utils.clip_grad_norm_(net2.parameters(), max_norm=20, norm_type=2)
optimizer.step()
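
A self-contained single-step sketch; the model, data, and optimizer here are placeholders:

import torch
import torch.nn as nn

net = nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))

optimizer.zero_grad()
loss = nn.CrossEntropyLoss()(net(x), y)
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), max_norm=20, norm_type=2)
optimizer.step()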

Chinese text in matplotlib

plt.rcParams['font.sans-serif'] = ['SimHei']  # a font that contains Chinese glyphs
plt.rcParams['axes.unicode_minus'] = False    # keep minus signs rendering correctly
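
A minimal check, assuming the SimHei font is installed:

import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.plot([-2, -1, 0, 1, 2])
plt.title('中文标题')  # renders as Chinese characters instead of empty boxes
plt.show()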

Sort several lists by one of them, keeping the rest aligned

u_idx = [4, 2, 3, 1, 0]
label = [0, 3, 4, 1, 2]
print([list(x) for x in zip(*sorted(zip(u_idx, label), key=lambda x: x[0]))][1])
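
The one-liner is dense; an equivalent step-by-step version:

u_idx = [4, 2, 3, 1, 0]
label = [0, 3, 4, 1, 2]
pairs = sorted(zip(u_idx, label), key=lambda x: x[0])  # sort the pairs by u_idx
u_idx_sorted, label_sorted = (list(t) for t in zip(*pairs))
print(label_sorted)  # [2, 1, 3, 4, 0]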

Plotting a confusion matrix

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, labels, title='Confusion Matrix', cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(labels)))
    plt.xticks(xlocations, labels, rotation=90)
    plt.yticks(xlocations, labels)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')


def showmatrix(y_true, y_pred, labels, titlename):
    tick_marks = np.array(range(len(labels))) + 0.5
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    cm = confusion_matrix(y_true, y_pred)
    np.set_printoptions(precision=2)
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  # row-normalize
    # normalized confusion matrix
    print(cm_normalized)
    dd = np.zeros(len(cm_normalized))  # the diagonal: per-class accuracy
    for i in range(len(dd)):
        dd[i] = cm_normalized[i, i]
    print(dd)
    max_c = cm_normalized.max()
    min_c = cm_normalized.min()
    print(sum(dd)/len(dd))  # average per-class accuracy
    np.save('result/matrixdata.npy', dd)
    plt.figure(figsize=(12, 8), dpi=120)
    ind_array = np.arange(len(labels))
    x, y = np.meshgrid(ind_array, ind_array)
    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm_normalized[y_val][x_val]
        if c > 0.01:
            if c > 0.5*(max_c+min_c):
                plt.text(x_val, y_val, "%0.2f" % (c,), color='white', fontsize=10, va='center', ha='center')
            else:
                plt.text(x_val, y_val, "%0.2f" % (c,), color='black', fontsize=10, va='center', ha='center')
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)
    plot_confusion_matrix(cm=cm_normalized, title=titlename, labels=labels)
    plt.savefig('result/confusionmatrix.svg', format='svg')
    plt.show()

Then, wherever the plot is needed:

showmatrix(val_label.cpu(), y_pred.cpu(), ['1', '2'], 'Confusion matrix')

Reading dict data / splitting a dataset

Data splitting; the number of splits is set by the length of the proportion tuple.

from typing import Any, Dict, List, Tuple

import numpy as np


def mysplit_n(dataset: Dict[str, np.ndarray], proportion: Tuple[int, ...]) -> List[Any]:
    # One output dict per split; split n takes the rows
    # [sum(proportion[:n]), sum(proportion[:n+1])) from every class
    out_list = [{} for _ in proportion]
    for name in dataset:
        for n in range(len(proportion)):
            index = np.arange(sum(proportion[:n]), sum(proportion[:n + 1]))
            out_list[n][name] = dataset[name][index]
    return out_list

First, load the data (args comes from argparse):

dataset = shuffle(dict(np.load('./datasets/RML2016_' + str(args.SNR) + 'dB_normalize_power.npz', allow_pickle=True)))
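
shuffle is not defined in the snippet above; a plausible helper, assuming each dict value is an array of samples to permute independently:

import numpy as np
from typing import Dict


def shuffle(dataset: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    # Hypothetical helper: permute the samples within each class independently
    for name in dataset:
        dataset[name] = dataset[name][np.random.permutation(len(dataset[name]))]
    return dataset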

Then split it:

dataset = dict(zip(['training', 'unlabeled', 'valid', 'test'],
                   mysplit_n(dataset, (args.num_labeled, args.num_unlabeled,
                                       args.num_valid, args.num_test))))

Unpack the splits into arrays and one-hot labels:

def make(dataset: Dict[str, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    # Stack every class into one array, build one-hot labels, then shuffle
    xs, ys = [], []
    nclass = len(dataset)
    for k, name in enumerate(dataset):
        xs.append(dataset[name])
        label = np.zeros((len(dataset[name]), nclass))
        label[:, k] = 1.0
        ys.append(label)
    xs = np.vstack(xs)
    ys = np.vstack(ys)
    index = np.arange(len(xs))
    np.random.shuffle(index)
    return xs[index], ys[index]

labeled_dataset, labeled_label = make(dataset['training'])
unlabeled_dataset, unlabeled_label = make(dataset['unlabeled'])
valid_dataset, valid_label = make(dataset['valid'])
test_dataset, test_label = make(dataset['test'])

Specifying return types

from typing import Tuple


def test(a: int, b: str) -> Tuple[int, str]:
    print(a, b)
    return 1000, '11'

Automatically create a folder for each paper and create a notes .md

import os
import glob
import shutil
import time


def mymkdirs(path_dir):
    if not os.path.exists(path_dir):
        os.makedirs(path_dir)
        print(path_dir + ' created')


def mypaperspace(path_paper):
    # Move the PDF into its own folder and start a notes.md with a date/author header
    shutil.move(src=path_paper + '.pdf', dst=os.path.join(path_paper, 'paper.pdf'))
    with open(os.path.join(path_paper, 'notes.md'), 'a') as f:
        f.write('Date: ' + time.strftime('%Y.%m.%d  %H:%M', time.localtime(time.time())))
        f.write('\n')
        f.write('Author: Joffrey LC')
        f.write('\n')
        f.write('\n------------------------------------------------')


cwd_path = 'H:\学习\阅读\面向智能反射面数能系统的波形设计\\test'
mymkdirs(os.path.join(cwd_path, '文章备份'))  # backup folder for the original PDFs
for filename in glob.glob(cwd_path + '/*.pdf'):
    shutil.copy(src=os.path.join(cwd_path, filename), dst=os.path.join(cwd_path, '文章备份'))
    path_dir = filename.split('.pdf')[0].split('\\')[-1]  # folder name = file name without extension
    mymkdirs(os.path.join(cwd_path, path_dir))
    mypaperspace(os.path.join(cwd_path, path_dir))

# To do:
# shutil.move can fail on files with very long names; write a custom move function when time permits
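
A minimal fallback sketch for the to-do above (my_move is a hypothetical helper): copy with metadata, then delete the original.

import os
import shutil


def my_move(src, dst):
    shutil.copy2(src, dst)  # copy the file together with its metadata
    os.remove(src)          # then delete the original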

Reading modulation recognition datasets in Python

RML2016a

import os
import pickle

dataset = pickle.load(open("dataset/RML2016.10a_dict.pkl", "rb"), encoding='bytes')
save_path = 'dataset/split_dB_rml2016a'
if not os.path.exists(save_path):
    os.mkdir(save_path)
print('processing')

all_type: set = set()
all_snr: set = set()
for name in dataset.keys():
    type_name, snr = name[0], name[1]
    all_type.add(type_name)
    all_snr.add(snr)
print('Modulation types:', all_type)
list_snr = list(all_snr)
list_snr.sort()
print('SNRs:', list_snr)

total = len(all_type)
for name in dataset.keys():
    type_name, snr = name[0], name[1]
    if int(snr) < 0:
        snr = 'b' + str(-snr)  # tag negative SNRs, e.g. 'b10' for -10 dB
    try:
        exec('dataset_' + str(snr) + 'dB')
    except NameError:
        exec('dataset_' + str(snr) + 'dB = {}')
    exec('dataset_' + str(snr) + 'dB[type_name] = dataset[name]')
    # Once a per-SNR dict has collected every modulation type, dump it to disk
    if eval('len(dataset_' + str(snr) + 'dB) == total'):
        with open(save_path + '/RML_2016_' + str(snr) + 'dB.pkl', 'wb') as f:
            pickle.dump(eval('dataset_' + str(snr) + 'dB'), f, pickle.HIGHEST_PROTOCOL)
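
The exec/eval pattern above works at module scope but is hard to read; an equivalent exec-free sketch using a plain dict of dicts:

from collections import defaultdict

per_snr = defaultdict(dict)  # SNR tag -> {modulation type: data}
for (type_name, snr) in dataset.keys():
    tag = 'b' + str(-snr) if int(snr) < 0 else str(snr)
    per_snr[tag][type_name] = dataset[(type_name, snr)]

for tag, d in per_snr.items():
    with open(save_path + '/RML_2016_' + tag + 'dB.pkl', 'wb') as f:
        pickle.dump(d, f, pickle.HIGHEST_PROTOCOL)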

RML2016c (the same script as RML2016a, apart from the input file and save path)

import os
import pickle

dataset = pickle.load(open("dataset/2016.04C.multisnr.pkl", "rb"), encoding='bytes')
save_path = 'dataset/split_dB_rml2016c'
if not os.path.exists(save_path):
    os.mkdir(save_path)
print('processing')

all_type: set = set()
all_snr: set = set()
for name in dataset.keys():
    type_name, snr = name[0], name[1]
    all_type.add(type_name)
    all_snr.add(snr)
print('Modulation types:', all_type)
list_snr = list(all_snr)
list_snr.sort()
print('SNRs:', list_snr)

total = len(all_type)
for name in dataset.keys():
    type_name, snr = name[0], name[1]
    if int(snr) < 0:
        snr = 'b' + str(-snr)  # tag negative SNRs, e.g. 'b10' for -10 dB
    try:
        exec('dataset_' + str(snr) + 'dB')
    except NameError:
        exec('dataset_' + str(snr) + 'dB = {}')
    exec('dataset_' + str(snr) + 'dB[type_name] = dataset[name]')
    # Once a per-SNR dict has collected every modulation type, dump it to disk
    if eval('len(dataset_' + str(snr) + 'dB) == total'):
        with open(save_path + '/RML_2016_' + str(snr) + 'dB.pkl', 'wb') as f:
            pickle.dump(eval('dataset_' + str(snr) + 'dB'), f, pickle.HIGHEST_PROTOCOL)

RML2018

import h5py
import pickle as pkl
import os
# file['X'] holds the data, file['Y'] the one-hot labels (grouped class by class),
# file['Z'] the SNR (-20 to 30 dB grouped together within each class)

classes = ['OOK', '4ASK', '8ASK', 'BPSK', 'QPSK',
'8PSK', '16PSK', '32PSK', '16APSK', '32APSK', '64APSK',
'128APSK', '16QAM', '32QAM', '64QAM', '128QAM',
'256QAM', 'AM-SSB-WC', 'AM-SSB-SC', 'AM-DSB-WC',
'AM-DSB-SC', 'FM', 'GMSK', 'OQPSK']

save_path = 'dataset/split_dB_rml2018'
if not os.path.exists(save_path):
    os.mkdir(save_path)

need_snr = [20]

file = h5py.File('H:/DATASETS/RML/GOLD_XYZ_OSC.0001_1024.hdf5', 'r')  # read-only is enough
for name in file.keys():
    print(name, file[name].shape)
num_per_class = int(file['Y'].shape[0]/file['Y'].shape[1])
print('Signals per class across all SNRs:', num_per_class)
num_per_snr = 4096

data_list = file['X']
print(len(data_list))
# name_idx = [i for i in range(len(classes)) if classes[i] in ['BPSK', 'QPSK', '8PSK', '16QAM', '32QAM', '64QAM']]
name_idx = [i for i in range(len(classes)) if classes[i] in classes]  # keeps all classes; edit like the commented line above to select a subset

print(name_idx)

snr_list = file['Z'][:num_per_class]
# print(snr_list, len(snr_list))
idx_snr = []
for i in range(len(need_snr)):
    idx_part = [j for j in range(num_per_class) if snr_list[j] == need_snr[i]]
    idx_snr.append(idx_part)

for snr_iidx in range(len(need_snr)):
    exec('dict_out_'+str(need_snr[snr_iidx])+'dB = {}')
    for idx in range(len(name_idx)):
        # Each class occupies a contiguous block of num_per_class rows in file['X']
        start = name_idx[idx]*num_per_class
        end = (name_idx[idx]+1)*num_per_class
        out_one_class_all_snr = data_list[start:end]
        exec('dict_out_'+str(need_snr[snr_iidx])+'dB[classes[name_idx[idx]]] = out_one_class_all_snr[idx_snr[snr_iidx]]')
    with open(save_path+'/RML_2018_' + str(need_snr[snr_iidx])+'dB.pkl', 'wb') as f:
        pkl.dump(eval('dict_out_' + str(need_snr[snr_iidx]) + 'dB'), f, pkl.HIGHEST_PROTOCOL)

Updated version; it runs faster:

import h5py
import numpy as np
import pickle
classes = ['OOK', '4ASK', '8ASK', 'BPSK', 'QPSK',
'8PSK', '16PSK', '32PSK', '16APSK', '32APSK', '64APSK',
'128APSK', '16QAM', '32QAM', '64QAM', '128QAM',
'256QAM', 'AM-SSB-WC', 'AM-SSB-SC', 'AM-DSB-WC',
'AM-DSB-SC', 'FM', 'GMSK', 'OQPSK']

file = h5py.File('H:/DATASETS/RML/GOLD_XYZ_OSC.0001_1024.hdf5', 'r')  # read-only
# for name in file.keys():
#     print(name, file[name].shape)
# num_per_class = int(file['Y'].shape[0]/file['Y'].shape[1])
# print('Signals per class across all SNRs:', num_per_class)
# num_per_snr = 4096

# file['X'] holds the data, file['Y'] the one-hot labels (grouped class by class),
# file['Z'] the SNR (-20 to 30 dB grouped together within each class)

data_all = file['X']
label_all = np.array(file['Y']).argmax(axis=1)
snr_all = file['Z']


need_snr = 20


# Row offsets with the desired SNR within the first class; the same offsets repeat in every class
need_idx = [idx for idx in range(int(len(label_all)/24)) if snr_all[idx] == need_snr]

need_idx_all = []
for i in range(24):
    need_idx_all.append([a+int(i*len(label_all)/24) for a in need_idx])

# for i in range(24):
#     exec('snr_'+str(i)+' = snr_all[need_idx_all[i]]')

data_out = {}
for i in range(24):
    data_out[classes[i]] = data_all[need_idx_all[i], :, :]

with open('datasets/dict_out_' + str(need_snr) + 'dB.pkl', 'wb') as f:
    pickle.dump(data_out, f, pickle.HIGHEST_PROTOCOL)
print('done')

load data

import numpy as np
import pickle as pkl


print('\nloading 2016c ... ')
# load RML2016c
with open('dataset/split_dB_rml2016c/RML_2016_0dB.pkl', 'rb') as f:
    data_2016c = pkl.load(f)
for k in data_2016c.keys():
    print(k, data_2016c[k].shape)

print('\nloading 2016a ... ')
# load RML2016a
with open('dataset/split_dB_rml2016a/RML_2016_0dB.pkl', 'rb') as f:
    data_2016a = pkl.load(f)
for k in data_2016a.keys():
    print(k, data_2016a[k].shape)


print('\nloading 2018 ... ')
# load RML2018
with open('dataset/split_dB_rml2018/RML_2018_20dB.pkl', 'rb') as f:
    data_2018 = pkl.load(f)
for k in data_2018.keys():
    print(k, data_2018[k].shape)


print('\nloading WIDEFT-A2197_1 ... ')
data_WIDEFT = np.load('dataset/A2197_1/1.npy')
print(data_WIDEFT.real.shape, data_WIDEFT.imag.shape)
