数据集007:垃圾分类数据集(含数据集下载链接)
数据集简介
本数据拥有
训练集:43685张;
验证集:5363张;
测试集:5363张;
总类别数:158类。
部分代码:
定义数据集
class MyDataset(Dataset):
def __init__(self, mode='train', transform=None):
super(MyDataset, self).__init__()
self.data = []
self.transform = transform
with open(f'{data_path}{mode}.txt') as f:
for line in f.readlines():
info = line.strip().split(' ')
if len(info) > 0:
self.data.append(
[data_path+'/'+info[0].strip(), info[1].strip()])
def __getitem__(self, idx):
image_file, label = self.data[idx]
img = Image.open(image_file).convert('RGB')
img = np.array(img)
# (Tensor(shape=[3, 227, 227], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
if self.transform is not None:
img = self.transform(img)
label = np.array([label], dtype="int64")
return img, label
def __len__(self):
定义ResNet网络
resnet50 = paddle.vision.models.resnet50(num_classes=158)
取单张测试图片进行可视化展示
import pylab as pl
import matplotlib.font_manager as fm
test_path = '/home/aistudio/Mydata/test1.txt'
myfont = fm.FontProperties(fname=r'/home/aistudio/simkai.ttf') # 设置字体
jetson_path = '/home/aistudio/Mydata/garbage_classification.json'
with open(jetson_path, 'r') as f1:
load_dict = json.load(f1)
with open(test_path, 'r') as f2:
img_path = f2.readline().strip().split(' ')
test_img_path = '/home/aistudio/Mydata/' + f'{img_path[0]}'
print('输入测试图片路径为:')
print(test_img_path)
clas = load_dict[f'{lab1}']#从字典中查找标签0对应的垃圾种类
img = cv2.imread(test_img_path)
plt.imshow(img)
plt.title(f'预测:{clas}', fontproperties = myfont, fontsize=20)