import os
import math
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.transforms import ToPILImage
# Placeholder edge weight stored for a neighbour that lies outside the image.
MAX_VAL = 0xFFFF

# Per-channel mean and standard deviation handed to transforms.Normalize
# (the values match the commonly quoted ImageNet RGB statistics).
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]
def cost_edge(a, b):
    """Return the Euclidean distance between the first three entries of *a* and *b*.

    Args:
        a: indexable holding at least three numeric components.
        b: indexable holding at least three numeric components.

    Returns:
        The square root of the summed squared differences of components
        0, 1 and 2, as a Python float. Any further components are ignored.
    """
    differences = (a[channel] - b[channel] for channel in range(3))
    squared_total = sum(delta * delta for delta in differences)
    return math.sqrt(squared_total)
def image_to_graph(image):
    """Build the per-pixel edge-weight table of an image.

    Every pixel gets four outgoing edges, always in this slot order:
    0 = below-left, 1 = below, 2 = below-right, 3 = right.

    Args:
        image: object whose ``shape`` unpacks to (channels, height, width)
            and which supports ``image[:, row, col]`` indexing.

    Returns:
        A list with ``height * width`` entries in row-major pixel order.
        Each entry is a list of four weights: ``cost_edge`` of the two
        pixels when the neighbour exists, ``MAX_VAL`` when it would fall
        outside the image.
    """
    _, h, w = image.shape
    # (row offset, column offset) of the four neighbours, in slot order.
    offsets = ((1, -1), (1, 0), (1, 1), (0, 1))
    edges = []
    for row in range(h):
        for col in range(w):
            pixel = image[:, row, col]
            weights = []
            for d_row, d_col in offsets:
                dst_row = row + d_row
                dst_col = col + d_col
                # Column 0 is a real pixel, so the lower bound is inclusive.
                if dst_row < h and 0 <= dst_col < w:
                    weights.append(cost_edge(pixel, image[:, dst_row, dst_col]))
                else:
                    weights.append(MAX_VAL)
            edges.append(weights)
    return edges
# Preprocessing pipeline: convert the PIL image to a tensor, then normalize
# each channel with the module-level ``means`` / ``stds``.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ]
)
class Union_Set:
    """Disjoint-set forest over ``vertec_num`` vertices with per-vertex bookkeeping.

    All list attributes are indexed by vertex id:

    * ``c``       - parent pointer; a negative value marks a root.
    * ``root``    - root id cached by ``init_root()``; -1 until computed.
    * ``c_Int``   - per-vertex threshold value, initialised to ``k``.
    * ``c_Size``  - size counter, initialised to 1.
    * ``c_Color`` - colour assigned by ``init_color()``; -1 until assigned.

    ``c_Int`` and ``c_Size`` are only initialised here; keeping them up to
    date after a merge is left to the caller.
    """

    def __init__(self, vertec_num, k=0.05):
        """Create ``vertec_num`` singleton sets whose threshold starts at ``k``."""
        self.c = [-1] * vertec_num
        self.root = [-1] * vertec_num
        self.root_Num = 0
        self.c_Int = [k] * vertec_num
        self.c_Size = [1] * vertec_num
        self.c_Color = [-1] * vertec_num
        self.vertec_num = vertec_num

    def find_root(self, idx):
        """Follow parent pointers from ``idx`` and return the root vertex id."""
        parent = self.c[idx]
        while parent > -1:
            idx = parent
            parent = self.c[idx]
        return idx

    def merge_set(self, ci, cj):
        """Attach ``cj`` below ``ci``; both are expected to be root ids."""
        self.c[cj] = ci

    def display(self):
        """Print the parent-pointer list followed by the cached root list."""
        print(self.c)
        print(self.root)

    def init_root(self):
        """Count the roots into ``root_Num`` and cache every vertex's root in ``root``."""
        self.root_Num = sum(1 for parent in self.c if parent < 0)

        for start in range(self.vertec_num):
            if self.root[start] > -1:
                continue
            # Walk towards the root, remembering every vertex on the way.
            path = [start]
            node = start
            parent = self.c[node]
            while parent > -1:
                # An ancestor that is already resolved short-cuts the walk.
                if self.root[parent] > -1:
                    node = self.root[parent]
                    break
                path.append(parent)
                node = parent
                parent = self.c[node]
            # Everything visited shares the root that was just found.
            for visited in path:
                self.root[visited] = node

    def init_color(self):
        """Give every component its own colour value in ``c_Color``.

        Colours are multiples of ``255 / root_Num`` handed out in the order
        in which components are first met while scanning the vertices.
        Vertices that already carry a colour (>= 0) are left untouched.
        Relies on ``init_root()`` having filled ``root`` and ``root_Num``.
        """
        if self.root_Num == 0:
            # No components recorded: nothing to colour, and the step below
            # would divide by zero.
            return
        step = 255 / self.root_Num
        color = 0
        color_of_root = {}
        # Single pass with a root -> colour table instead of one full scan
        # of all vertices per component.
        for idx in range(self.vertec_num):
            if self.c_Color[idx] >= 0:
                continue
            root = self.root[idx]
            if root not in color_of_root:
                color += step
                color_of_root[root] = color
            self.c_Color[idx] = color_of_root[root]
# Shared converter instance; the script below uses it to turn each mask tensor
# into a PIL image for plotting.
tensor_to_pil = ToPILImage()
def display_img(img1, img2):
    """Show *img1* and *img2* side by side in a one-row, two-column figure."""
    for position, picture in enumerate((img1, img2), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
    plt.show()
class graph:
    """Flat, weight-sorted edge list of an image.

    Attributes:
        c, h, w: the three dimensions unpacked from ``img.shape``.
        vertec_num: number of pixels (``h * w``).
        edges: 1-D tensor with four weights per pixel, laid out as
            ``edges[4 * pixel + slot]``.
        edges_num: length of ``edges``.
        idx_sorted: edge indices ordered by ascending weight.
    """

    def __init__(self, img):
        """Build the edge tensor for ``img`` and sort its edges by weight."""
        # get vertec number
        _c, _h, _w = img.shape
        self.c = _c
        self.h = _h
        self.w = _w
        self.vertec_num = _h * _w

        # transform the image to graph: four weights per pixel, flattened
        self.edges = torch.tensor(image_to_graph(img))
        self.edges = self.edges.view([self.vertec_num * 4])
        self.edges_num = len(self.edges)

        # sorted edges (only the ordering is kept)
        _, _idx_sorted = torch.sort(self.edges, descending=False)
        self.idx_sorted = _idx_sorted

    def neib_idx(self, my, idx):
        """Return the flat pixel index of neighbour slot ``idx`` of pixel ``my``.

        Slot 3 is the pixel to the right; slots 0, 1 and 2 are the pixels
        below-left, below and below-right.

        No bounds check is done: on the left/right image border the result
        wraps into an adjacent row, and on the bottom row it can be
        ``>= vertec_num``.
        NOTE(review): callers appear to rely on such edges carrying a
        placeholder weight so they are never merged - confirm.
        """
        # Integer arithmetic only: true division on a torch index goes
        # through float32 and loses precision for large indices.
        my_h, my_w = divmod(int(my), self.w)
        idx = int(idx)
        if idx == 3:
            my_w += 1
        else:
            my_h += 1
            my_w += idx - 1
        return my_h * self.w + my_w

    def get_edge_vertec(self, edge_idx):
        """Return the ``(source, neighbour)`` pixel ids of edge ``edge_idx`` as ints.

        ``edge_idx`` may be a Python int or a 0-d integer tensor.
        """
        vi, slot = divmod(int(edge_idx), 4)
        vj = self.neib_idx(vi, slot)
        return vi, vj
def segment_img(graph, k=0.05):
    """Segment the image behind ``graph`` and return a colour-coded mask.

    Edges are visited in ascending weight order. Two components are merged
    when the edge weight does not exceed the smaller of their thresholds;
    the merged component's threshold becomes ``weight + k / size``.
    (This looks like Felzenszwalb-Huttenlocher style segmentation - confirm.)

    Args:
        graph: object exposing ``vertec_num``, ``edges``, ``idx_sorted``
            (a tensor of edge indices), ``get_edge_vertec``, ``h`` and ``w``.
        k: scale constant; larger values allow more merging.

    Returns:
        A ``torch.uint8`` tensor reshaped to ``(-1, graph.h, graph.w)`` in
        which every pixel holds the colour value of its component.

    Side effects:
        Prints the number of segments found.
    """
    c = Union_Set(graph.vertec_num, k)

    # Pull the edge order out of the tensor once so ids are plain ints.
    for edge_idx in graph.idx_sorted.tolist():
        weight = graph.edges[edge_idx]
        if weight >= MAX_VAL:
            # MAX_VAL marks a neighbour outside the image. The order is
            # ascending, so every remaining edge is such a placeholder;
            # merging across one would join unrelated pixels.
            break

        vi, vj = graph.get_edge_vertec(edge_idx)
        if vj >= graph.vertec_num:
            continue

        ci = c.find_root(vi)
        cj = c.find_root(vj)
        if ci != cj and weight <= min(c.c_Int[ci], c.c_Int[cj]):
            c.merge_set(ci, cj)
            c.c_Size[ci] += c.c_Size[cj]
            c.c_Size[cj] = 0
            c.c_Int[ci] = weight + k / c.c_Size[ci]

    c.init_root()
    c.init_color()

    print(" segments: ", c.root_Num)
    mask = torch.tensor(c.c_Color, dtype=torch.uint8)
    return mask.reshape((-1, graph.h, graph.w))
# Demo: segment "test1.png" for a series of k values and plot the results
# next to the original image.

# open image
img_raw = Image.open("test1.png")
img = img_raw
# Normalize() is configured with exactly three channel statistics, so every
# mode other than RGB (grayscale, palette, RGBA, ...) is converted, not only
# single-band images.
if img.mode != "RGB":
    img = img.convert("RGB")
img = transform(img)

g = graph(img)
print( g.h, g.w)

# Plot grid: the first cell shows the raw image, the rest one mask per k.
k = 0.1
col = 5
row = 2
size = col * row

plt.subplot(row, col, 1)
plt.title("raw")
plt.imshow( img_raw )

for i in range(size - 1):
    mask = segment_img(g, k)
    img = tensor_to_pil( mask )
    plt.subplot(row, col, i+2)
    plt.title( k )
    plt.imshow( img )
    k += 300
plt.show()