[DL]Efficient Graph-Based Image Segmentation
"""Graph-based image segmentation demo (Felzenszwalb-Huttenlocher style).

Every pixel is a vertex.  Each pixel owns four edges (down-left, down,
down-right, right) weighted by the Euclidean colour distance to that
neighbour.  Edges are visited in ascending weight order and two components
are merged when the edge weight does not exceed either component's current
threshold.
"""
import math
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import ToPILImage

# Sentinel weight for edges whose neighbour lies outside the image.
MAX_VAL = 65535
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

# (row offset, column offset) of the four neighbours owned by each pixel,
# listed in edge-slot order 0..3.
_NEIGHBOR_OFFSETS = ((1, -1), (1, 0), (1, 1), (0, 1))


def cost_edge(a, b):
    """Return the Euclidean distance between two pixel vectors.

    Args:
        a: Indexable sequence of channel values of the first pixel.
        b: Indexable sequence of channel values of the second pixel.

    Returns:
        The distance as a Python float, computed over all channels.
    """
    ret = 0
    for chan_a, chan_b in zip(a, b):
        diff = chan_a - chan_b
        ret += diff * diff
    return math.sqrt(ret)


def image_to_graph(image):
    """Build the per-pixel edge-weight table of an image.

    Args:
        image: Array-like of shape (channels, height, width).

    Returns:
        A list with one entry per pixel (row-major).  Each entry is a list
        of four weights for the down-left, down, down-right and right
        neighbours; a neighbour outside the image gets ``MAX_VAL``.
    """
    _, h, w = image.shape
    edges = []
    for row in range(h):
        for col in range(w):
            pixel = image[:, row, col]
            costs = []
            for d_row, d_col in _NEIGHBOR_OFFSETS:
                dst_row = row + d_row
                dst_col = col + d_col
                # Column 0 is a valid neighbour, hence ``0 <=`` and not ``0 <``.
                if dst_row < h and 0 <= dst_col < w:
                    costs.append(cost_edge(pixel, image[:, dst_row, dst_col]))
                else:
                    costs.append(MAX_VAL)
            edges.append(costs)
    return edges


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])


class Union_Set:
    """Disjoint-set forest over the pixels, plus per-component bookkeeping.

    Attributes set in ``__init__``:
        c: parent index of every vertex; ``-1`` marks a root.
        root: final root of every vertex, filled by ``init_root``.
        root_Num: number of components, filled by ``init_root``.
        c_Int: merge threshold of every component, indexed by its root.
        c_Size: vertex count of every component, indexed by its root.
        c_Color: grey level of every vertex, filled by ``init_color``.
        vertec_num: total number of vertices.
    """

    def __init__(self, vertec_num, k=0.05):
        self.c = [-1] * vertec_num
        self.root = [-1] * vertec_num
        self.root_Num = 0
        # A single-vertex component starts with threshold k.
        self.c_Int = [k] * vertec_num
        self.c_Size = [1] * vertec_num
        self.c_Color = [-1] * vertec_num
        self.vertec_num = vertec_num

    def find_root(self, idx):
        """Return the root of the component containing vertex ``idx``.

        The visited path is re-pointed straight at the root (path
        compression) so later look-ups are shorter; the result is the same
        as a plain walk up the parent links.
        """
        idx = int(idx)
        root = idx
        while self.c[root] > -1:
            root = self.c[root]
        while self.c[idx] > -1:
            parent = self.c[idx]
            self.c[idx] = root
            idx = parent
        return root

    def merge_set(self, ci, cj):
        """Attach root ``cj`` underneath root ``ci``."""
        # Store plain ints so roots stay hashable and comparable.
        self.c[int(cj)] = int(ci)

    def display(self):
        """Print the parent table and the resolved-root table."""
        print(self.c)
        print(self.root)

    def init_root(self):
        """Count the components and resolve the root of every vertex."""
        self.root_Num = sum(1 for parent in self.c if parent < 0)

        for start in range(self.vertec_num):
            if self.root[start] > -1:
                continue
            path = [start]
            node = start
            parent = self.c[node]
            while parent > -1:
                # An ancestor that is already resolved gives the answer
                # for the whole path walked so far.
                if self.root[parent] > -1:
                    node = self.root[parent]
                    break
                path.append(parent)
                node = parent
                parent = self.c[node]
            for member in path:
                self.root[member] = node

    def init_color(self):
        """Assign one grey level per component into ``c_Color``.

        Components are numbered in order of their first vertex; the i-th
        component (1-based) gets ``i * 255 / root_Num``, accumulated by
        repeated addition.  Must be called after ``init_root``.
        """
        if self.root_Num == 0:
            return
        step = 255 / self.root_Num
        color = 0
        color_of_root = {}
        for idx in range(self.vertec_num):
            if self.c_Color[idx] >= 0:
                continue
            root = self.root[idx]
            if root not in color_of_root:
                # Never hand out more colours than there are components.
                if len(color_of_root) == self.root_Num:
                    continue
                color += step
                color_of_root[root] = color
            self.c_Color[idx] = color_of_root[root]


tensor_to_pil = ToPILImage()


def display_img(img1, img2):
    """Show two images side by side in a blocking matplotlib window."""
    plt.subplot(1, 2, 1)
    plt.imshow(img1)
    plt.subplot(1, 2, 2)
    plt.imshow(img2)
    plt.show()


class graph:
    """Edge list of an image, flattened and sorted by weight.

    Attributes:
        c, h, w: channel count, height and width of the image.
        vertec_num: number of pixels (``h * w``).
        edges: 1-D tensor of ``vertec_num * 4`` weights; edge ``e`` belongs
            to pixel ``e // 4`` and uses neighbour slot ``e % 4``.
        edges_num: number of entries in ``edges``.
        idx_sorted: edge indices in ascending weight order.
    """

    def __init__(self, img):
        self.c, self.h, self.w = img.shape
        self.vertec_num = self.h * self.w

        # transform the image to a flat edge-weight tensor
        weights = torch.tensor(image_to_graph(img))
        self.edges = weights.view([self.vertec_num * 4])
        self.edges_num = len(self.edges)

        # sorted edges
        _, self.idx_sorted = torch.sort(self.edges, descending=False)

    def neib_idx(self, my, idx):
        """Return the flat index of neighbour slot ``idx`` of pixel ``my``.

        Slots 0, 1, 2 are down-left, down, down-right; slot 3 is right.
        The result is not bounds-checked and may refer to a wrapped-around
        or out-of-range pixel.
        """
        my_h, my_w = divmod(int(my), self.w)
        idx = int(idx)
        if idx == 3:
            my_w += 1
        else:
            my_h += 1
            my_w += idx - 1
        return my_h * self.w + my_w

    def get_edge_vertec(self, edge_idx):
        """Return the ``(vi, vj)`` vertex pair joined by edge ``edge_idx``."""
        # Integer arithmetic: dividing a tensor index would go through
        # float32 and lose precision for very large images.
        vi, slot = divmod(int(edge_idx), 4)
        vj = self.neib_idx(vi, slot)
        return vi, vj


def segment_img(graph, k=0.05):
    """Segment the image described by ``graph``.

    Args:
        graph: A ``graph`` instance.
        k: Scale parameter; larger values produce larger components.

    Returns:
        A ``uint8`` tensor of shape ``(1, h, w)`` holding one grey level
        per component.  The number of components is printed as well.
    """
    c = Union_Set(graph.vertec_num, k)

    for edge_idx in graph.idx_sorted.tolist():
        weight = graph.edges[edge_idx]
        # Sentinel edges point outside the image.  The order is ascending,
        # so every remaining edge is a sentinel as well.
        if weight >= MAX_VAL:
            break
        vi, vj = graph.get_edge_vertec(edge_idx)
        if vj >= graph.vertec_num:
            continue
        ci = c.find_root(vi)
        cj = c.find_root(vj)
        if ci != cj and weight <= min(c.c_Int[ci], c.c_Int[cj]):
            c.merge_set(ci, cj)
            c.c_Size[ci] += c.c_Size[cj]
            c.c_Size[cj] = 0
            # New threshold: weight of the merging edge plus k / |C|.
            c.c_Int[ci] = weight + k / c.c_Size[ci]

    c.init_root()
    c.init_color()
    print(" segments: ", c.root_Num )

    mask = torch.tensor(c.c_Color, dtype=torch.uint8)
    return mask.reshape((-1, graph.h, graph.w))


# open image
img_raw = Image.open("test1.png")
img = img_raw
# Normalize() above expects exactly three channels, so grey, palette and
# alpha-carrying images are all converted.
if img.mode != "RGB":
    img = img.convert("RGB")
img = transform(img)

g = graph(img)
print( g.h, g.w)

k = 0.1
col = 5
row = 2
size = col * row
plt.subplot(row, col, 1)
plt.title("raw")
plt.imshow(img_raw)
# NOTE(review): a step of 300 is far above the colour distances of a
# normalized image, so later panels probably collapse to one segment -
# confirm the intended range of k.
for i in range(size - 1):
    mask = segment_img(g, k)
    img = tensor_to_pil(mask)
    plt.subplot(row, col, i + 2)
    plt.title(k)
    plt.imshow(img)
    k += 300
plt.show()
日历
最新微语
- 有的时候,会站在分叉路口,不知道向左还是右
2023-12-26 15:34
- 繁花乱开,鸟雀逐风。心自宁静,纷扰不闻。
2023-03-14 09:56
- 对于不可控的事,我们保持乐观,对于可控的事情,我们保持谨慎。
2023-02-09 11:03
- 小时候,
暑假意味着无忧无虑地玩很长一段时间,
节假意味着好吃好喝还有很多长期不见的小朋友来玩...
长大后,
这是女儿第一个暑假,
一个半月...
2022-07-11 08:54
- Watching the autumn leaves falling as you grow older together
2018-10-25 09:45
分类
最新评论
- Goonog
i get it now :) - 萧
@Fluzak:The web host... - Fluzak
Nice blog here! Also... - Albertarive
In my opinion you co... - ChesterHep
What does it plan? - ChesterHep
No, opposite. - mojoheadz
Everything is OK!... - Josephmaigh
I just want to say t... - ChesterHep
What good topic - AnthonyBub
Certainly, never it ...
发表评论: