[DL]Efficient Graph-Based Image Segmentation

2022-6-29 写技术

import os
import math
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.transforms import ToPILImage


# Sentinel cost stored for an edge whose neighbouring pixel lies outside the
# image (see image_to_graph).
MAX_VAL = 65535

# Per-channel mean and standard deviation passed to transforms.Normalize.
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

def cost_edge(a, b):
    """Return the Euclidean distance between the first three entries of a and b.

    Args:
        a: indexable holding at least three numeric components.
        b: indexable holding at least three numeric components.

    Returns:
        The square root of the summed squared component differences, as a
        Python float.
    """
    diffs = (a[ch] - b[ch] for ch in range(3))
    return math.sqrt(sum(d * d for d in diffs))

def image_to_graph(image):
    """Build the edge-cost table of the pixel grid graph of ``image``.

    Every pixel is linked to four of its eight neighbours -- down-left, down,
    down-right and right.  The remaining four directions are covered by the
    entries of the neighbouring pixels, so each edge of the 8-connected grid
    is listed once.

    Args:
        image: array-like whose ``shape`` unpacks to
            ``(channels, height, width)`` and which supports
            ``image[:, row, col]`` indexing.  ``cost_edge`` reads the first
            three channels of each pixel.

    Returns:
        A list with one entry per pixel in row-major order.  Each entry is a
        list of four costs ordered [down-left, down, down-right, right]; a
        neighbour that falls outside the image gets the sentinel ``MAX_VAL``.
    """
    _, h, w = image.shape

    # (row offset, column offset) per edge slot.  The slot order is part of
    # the contract: graph.neib_idx decodes a slot number back into exactly
    # these offsets.
    offsets = ((1, -1), (1, 0), (1, 1), (0, 1))

    edges = []
    for row in range(h):
        for col in range(w):
            pixel = image[:, row, col]
            costs = []
            for d_row, d_col in offsets:
                dst_row = row + d_row
                dst_col = col + d_col
                # Column 0 is a valid neighbour, hence ``0 <=`` rather than
                # ``0 <``; the row offset is never negative so only the lower
                # image border needs checking for rows.
                if dst_row < h and 0 <= dst_col < w:
                    costs.append(cost_edge(pixel, image[:, dst_row, dst_col]))
                else:
                    costs.append(MAX_VAL)
            edges.append(costs)

    return edges

# Preprocessing pipeline: PIL image -> tensor, then per-channel normalisation
# with the module-level means / stds.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ]
)


class Union_Set:
    """Disjoint-set forest over ``vertec_num`` vertices with per-set data.

    All attributes are plain lists indexed by vertex id:

    * ``c``       -- parent of each vertex, ``-1`` for a root.
    * ``root``    -- root of each vertex as cached by ``init_root``.
    * ``c_Int``   -- per-set threshold, initialised to ``k``; maintained by
      the caller at the set's root.
    * ``c_Size``  -- per-set size, initialised to 1; maintained by the caller
      at the set's root.
    * ``c_Color`` -- colour of each vertex as assigned by ``init_color``,
      ``-1`` until then.

    ``root_Num`` is the number of sets counted by the last ``init_root`` call.
    """

    def __init__(self, vertec_num, k=0.05):
        """Create ``vertec_num`` singleton sets whose threshold is ``k``."""
        self.c = [-1] * vertec_num
        self.root = [-1] * vertec_num
        self.root_Num = 0
        self.c_Int = [k] * vertec_num
        self.c_Size = [1] * vertec_num
        self.c_Color = [-1] * vertec_num
        self.vertec_num = vertec_num

    def find_root(self, idx):
        """Return the root of the set containing vertex ``idx``.

        The walked path is re-pointed directly at the root (path compression)
        so that repeated lookups stay cheap even after many merges.
        """
        root = idx
        while self.c[root] > -1:
            root = self.c[root]

        # Second pass: hang every vertex on the path directly under the root.
        while idx != root:
            parent = self.c[idx]
            self.c[idx] = root
            idx = parent

        return root

    def merge_set(self, ci, cj):
        """Attach ``cj`` under ``ci``.

        The parent pointer is overwritten unconditionally, so both arguments
        are expected to be roots (as returned by ``find_root``).
        """
        self.c[cj] = ci

    def display(self):
        """Print the parent list followed by the cached root list."""
        print(self.c)
        print(self.root)

    def init_root(self):
        """Count the sets and cache the root of every vertex in ``self.root``.

        Both values are recomputed from scratch, so the method may be called
        again after further merges without returning stale roots.
        """
        self.root_Num = sum(
            1 for idx in range(self.vertec_num) if self.c[idx] < 0
        )
        self.root = [self.find_root(idx) for idx in range(self.vertec_num)]

    def init_color(self):
        """Give every set its own grey level in ``(0, 255]``.

        Must be called after ``init_root``.  Sets are numbered in order of
        their first vertex; the n-th set gets the n-th multiple of
        ``255 / root_Num`` (accumulated by repeated addition).  Every entry
        of ``c_Color`` is overwritten.

        Raises:
            ZeroDivisionError: if there are vertices but ``root_Num`` is 0,
                i.e. ``init_root`` has not been called.
        """
        if self.vertec_num == 0:
            # Nothing to colour; also avoids dividing by a zero root count.
            return

        step = 255 / self.root_Num
        color = 0
        color_of_root = {}
        # One pass over the vertices instead of one pass per set.
        for idx, root in enumerate(self.root):
            if root not in color_of_root:
                color += step
                color_of_root[root] = color
            self.c_Color[idx] = color_of_root[root]



# Converter instance; the driver script applies it to the segmentation mask
# tensor before handing the result to matplotlib.
tensor_to_pil = ToPILImage()

def display_img(img1, img2):
    """Draw ``img1`` and ``img2`` side by side in a 1x2 grid and show the figure."""
    for slot, picture in enumerate((img1, img2), start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(picture)

    plt.show()


class graph:
    """Flattened, weight-sorted edge table of an image's pixel grid.

    Attributes:
        c, h, w: the three dimensions unpacked from ``img.shape``.
        vertec_num: number of pixels (``h * w``).
        edges: 1-D tensor holding four edge costs per pixel, so edge ``e``
            belongs to pixel ``e // 4`` and neighbour slot ``e % 4``.
        edges_num: length of ``edges``.
        idx_sorted: edge indices ordered by ascending cost.
    """

    def __init__(self, img):
        """Build and sort the edge table of ``img``."""
        # get vertec number
        _c, _h, _w = img.shape
        self.c = _c
        self.h = _h
        self.w = _w
        self.vertec_num = _h * _w

        # transform the image to graph, then flatten to one cost per edge
        self.edges = torch.tensor(image_to_graph(img))
        self.edges = self.edges.view([self.vertec_num * 4])
        self.edges_num = len(self.edges)

        # sorted edges
        _, _idx_sorted = torch.sort(self.edges, descending=False)
        self.idx_sorted = _idx_sorted

    def neib_idx(self, my, idx):
        """Return the vertex id of neighbour slot ``idx`` of vertex ``my``.

        Slots are 0 = down-left, 1 = down, 2 = down-right, 3 = right.  No
        bounds check is made: for a pixel on the image border the result may
        be ``>= vertec_num`` or wrap around to another row.

        Args:
            my: vertex id (Python int or 0-d integer tensor).
            idx: neighbour slot in ``0..3`` (Python int or 0-d integer tensor).

        Returns:
            The neighbour's vertex id as a Python int.
        """
        # Integer arithmetic only: float division loses precision for large
        # ids, especially when the id arrives as a tensor.
        my_h, my_w = divmod(int(my), self.w)
        idx = int(idx)

        if idx == 3:
            my_w += 1
        else:
            my_h += 1
            my_w += idx - 1

        return my_h * self.w + my_w

    def get_edge_vertec(self, edge_idx):
        """Return the ``(vi, vj)`` vertex ids joined by edge ``edge_idx``.

        Args:
            edge_idx: position in ``edges`` (Python int or 0-d integer
                tensor, e.g. an element of ``idx_sorted``).

        Returns:
            A tuple of two Python ints: the owning pixel and its neighbour as
            computed by ``neib_idx``.
        """
        vi, slot = divmod(int(edge_idx), 4)
        vj = self.neib_idx(vi, slot)
        return vi, vj



def segment_img(graph, k=0.05):
    """Segment the image behind ``graph`` and return a grey-level mask.

    Edges are visited in ascending cost order.  Two different components are
    merged when the edge cost does not exceed the smaller of their
    thresholds; the merged component's threshold becomes
    ``cost + k / size``.  Larger ``k`` therefore yields larger segments.

    Args:
        graph: object exposing ``vertec_num``, ``edges``, ``idx_sorted``,
            ``h``, ``w`` and ``get_edge_vertec`` (the ``graph`` class of this
            file).
        k: initial threshold of every single-pixel component and scale of the
            size-dependent term.

    Returns:
        A ``torch.uint8`` tensor of shape ``(1, h, w)`` in which every
        segment has its own grey level.
    """
    c = Union_Set(graph.vertec_num, k)

    # Plain Python numbers keep the thresholds as floats instead of 0-d
    # tensors and make the main loop much cheaper.
    weights = graph.edges.tolist()

    for edge_idx in graph.idx_sorted.tolist():
        weight = weights[edge_idx]

        # MAX_VAL marks an edge whose neighbour is outside the image.  Such
        # an edge must never merge anything, whatever k is, and because the
        # edges are sorted every remaining one is a sentinel as well.
        # Assumes real edge costs stay below MAX_VAL.
        if weight >= MAX_VAL:
            break

        vi, vj = graph.get_edge_vertec(edge_idx)
        vi, vj = int(vi), int(vj)
        if vj >= graph.vertec_num:
            continue

        ci = c.find_root(vi)
        cj = c.find_root(vj)
        if ci == cj:
            continue

        if weight <= min(c.c_Int[ci], c.c_Int[cj]):
            c.merge_set(ci, cj)
            c.c_Size[ci] += c.c_Size[cj]
            c.c_Size[cj] = 0
            c.c_Int[ci] = weight + k / c.c_Size[ci]

    c.init_root()
    c.init_color()

    print(" segments: ", c.root_Num )

    mask = torch.tensor(c.c_Color, dtype=torch.uint8)
    return mask.reshape((1, graph.h, graph.w))


# open image
img_raw = Image.open("test1.png")
img = img_raw
# The Normalize step of `transform` is configured with three channel
# statistics, so anything that is not already RGB (greyscale, palette,
# RGBA, LA, ...) has to be converted first.
if img.mode != "RGB":
    img = img.convert("RGB")
img = transform(img)

# The edge table does not depend on k, so it is built once and reused.
g = graph(img)
print( g.h, g.w)

# Plot grid: the raw image followed by one segmentation per k value.
k = 0.1
col = 5
row = 2
size = col * row


plt.subplot(row, col, 1)
plt.title("raw")
plt.imshow( img_raw )


for i in range(size - 1):
    mask = segment_img(g, k)
    img = tensor_to_pil( mask )

    # Slot 1 holds the raw image, so the results start at slot 2.
    plt.subplot(row, col, i+2)
    plt.title( k )
    plt.imshow( img )

    k += 300

plt.show()

标签: neural network linux python machine_learning DL

发表评论:

Powered by anycle 湘ICP备15001973号-1