WHCSRL 技术网

使用stylegan2训练自己的数据集

官方开源链接

链接: link.

对数据集进行处理

数据集resize图片尺寸

// resize image
from glob import glob
from PIL import Image
import os
from tqdm import tqdm
from tqdm._tqdm import trange
img_path = glob("./resize/*.png")
path_save = "./resize/"
a = range(0, len(img_path))
i = 0
for file in tqdm(img_path):
    name = os.path.join(path_save, "%%d.png" %% a[i])
    im = Image.open(file)
    im.thumbnail((1024, 1024))
    print(im.format, im.size, im.mode)
    im.save(name, 'png')
    i += 1


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

生成数据集对应的tfrecords格式

// 第一个目录参数为tfrecords格式存放的目录,第二个目录参数为resize后images图片路径
python dataset_tool.py create_from_images ~/datasets/my-custom-dataset ~/my-custom-images
//可视化数据集
python dataset_tool.py display ~/datasets/my-custom-dataset
  • 1
  • 2
  • 3
  • 4

训练

// config文件分为f和e,对应用不同的显存大小训练
python run_training.py --num-gpus=1 --data-dir=datasets --config=config-e --dataset=custome_dataset1 --mirror-augment=true
  • 1
  • 2

测试

// seeds为生成的照片索引,可以取多个值
# Generate 1000 random images without truncation
python run_generator.py generate-images --seeds=0-999 --truncation-psi=1.0  --network=results/00006-stylegan2-ffhq-8gpu-config-f/networks-final.pkl
#example
python run_generator.py generate-images --seeds=9,66,286 --truncation-psi=1.0 --network=results/00007-stylegan2-custome_dataset-1gpu-config-e/network-snapshot-001323.pkl

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
推荐阅读