# fer2013.py
import os
import random

import cv2
import numpy as np
import pandas as pd

# Absolute path of the directory that contains this script.
curdir = os.path.dirname(os.path.abspath(__file__))


def gen_record(csvfile, channel):
    """Export every image stored in a FER2013-style CSV file to disk.

    The CSV must provide three columns: ``emotion`` (numeric class label),
    ``pixels`` (space-separated pixel values, parsed as ``uint8``, whose
    count is reshaped into a square image) and ``Usage`` (name of the
    group the row belongs to). Each row is written to
    ``<curdir>/<Usage>/<emotion>/<Usage>_<random>.jpg``; missing
    directories are created on the way.

    Args:
        csvfile: Path of the CSV file to read.
        channel: Unused; accepted only so existing callers keep working.

    Raises:
        SystemExit: With status -1 when OpenCV reports that an image could
            not be written.
    """
    # Read every column as text; numeric conversion is done explicitly below.
    data = pd.read_csv(csvfile, delimiter=',', dtype=str)
    # `np.float` was removed in NumPy 1.24; it was an alias of builtin float.
    labels = np.array(data['emotion'], float)

    images = np.array(
        [np.fromstring(pixels, np.uint8, sep=' ') for pixels in data['pixels']]
    )
    # Every row is a flattened square image: recover the side length and
    # reshape to (n_images, side, side).
    side = int(np.sqrt(images.shape[-1]))
    images = images.reshape(images.shape[0], side, side)

    # One top-level output directory per distinct `Usage` value.
    class_dir = {}
    for usage in set(data['Usage']):
        dest = os.path.join(curdir, usage)
        os.makedirs(dest, exist_ok=True)
        class_dir[usage] = dest

    for label, img, usage in zip(labels, images, data['Usage']):
        # Images are grouped in a sub-directory named after the integer label.
        destdir = os.path.join(class_dir[usage], str(int(label)))
        os.makedirs(destdir, exist_ok=True)

        filepath = unique_name(destdir, usage)
        if not filepath:
            # No usable name was produced; skip instead of handing a falsy
            # path to OpenCV.
            continue
        print('[^_^] Write image to %s' % filepath)

        if not cv2.imwrite(filepath, img):
            print('Error')
            # Same exit status as before, without relying on the `exit`
            # helper that is only injected by the `site` module.
            raise SystemExit(-1)


def unique_name(pardir, prefix, suffix='jpg'):
    """Return a path inside *pardir* that does not exist yet.

    The file name has the form ``<prefix>_<n>.<suffix>``, where ``n`` is a
    random integer between 1 and 10**8 (inclusive). Candidates are drawn
    until one is found that is not present on disk.

    The file itself is not created, so the name is only guaranteed to be
    free at the moment of the check.

    Args:
        pardir: Directory the file is meant to live in.
        prefix: Leading part of the file name.
        suffix: File extension without the dot.

    Returns:
        The joined path of the first candidate that does not exist.
    """
    # Loop instead of recursing: the previous recursive retry dropped its
    # result, so a name collision made the function return None.
    while True:
        filename = '{0}_{1}.{2}'.format(prefix, random.randint(1, 10 ** 8), suffix)
        filepath = os.path.join(pardir, filename)
        if not os.path.exists(filepath):
            return filepath


if __name__ == '__main__':
    # Convert the fer2013.csv file that sits next to this script.
    gen_record(os.path.join(curdir, 'fer2013.csv'), 1)

    # ##################### test
    # tmp = unique_name('./Training','Training')
    # print(tmp)