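# -*- coding: utf-8 -*-
# Minimal ray tracer: spheres are the only primitive, shading is flat,
# the camera is an orthographic projection, and the resolution is fixed
# (200x200).
# (Original note: プリミティブは球体のみ、シェーディングはフラットシェーディング、カメラは正投影、解像度固定。)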
import math
from itertools import product
from functools import partial
def range2(rx, ry):
    """Iterate over every (i, j) index pair of an rx-by-ry grid.

    The first index varies slowest: (0, 0), (0, 1), ..., (1, 0), ...
    """
    axes = (range(rx), range(ry))
    return product(*axes)
class Vector(object):
    """A mutable 3-component vector.

    ``a * b`` between two vectors is the dot product (a scalar), ``a - b`` is
    component-wise subtraction returning a new Vector, and normalize()
    rescales the vector in place.
    """

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    def __repr__(self):
        return '%s(%r, %r, %r)' % (type(self).__name__, self.x, self.y, self.z)

    def normalize(self):
        """Scale this vector to unit length in place (returns None).

        Raises:
            ZeroDivisionError: if the vector has zero length. This is the same
                exception type the bare division used to raise, but with a
                message that says what actually went wrong.
        """
        length = math.sqrt(self.x*self.x + self.y*self.y + self.z*self.z)
        if length == 0.0:
            raise ZeroDivisionError('cannot normalize a zero-length vector')
        self.x /= length
        self.y /= length
        self.z /= length

    def __mul__(self, v):
        """Dot product with another vector."""
        return self.x*v.x + self.y*v.y + self.z*v.z

    def __sub__(self, v):
        """Component-wise difference, always returned as a plain Vector."""
        return Vector(self.x - v.x, self.y - v.y, self.z - v.z)

    def printval(self):
        """Print the three components separated by single spaces."""
        # A single pre-formatted string prints identically under the Python 2
        # print statement and the Python 3 print function.
        print('%s %s %s' % (self.x, self.y, self.z))
class Color(Vector):
    """An RGB colour stored in the x/y/z slots of a Vector.

    Channels are nominally in [0.0, 1.0]; printval() converts them to the
    0-255 integer range used by the plain-text PPM writer.
    """

    def __init__(self, r, g, b):
        super(Color, self).__init__(r, g, b)

    @staticmethod
    def _to_byte(channel):
        """Convert one [0, 1] channel to an int in [0, 255].

        Values outside [0, 1] are clamped so the PPM never contains a sample
        above its declared maxval (255) or below 0. The result is rounded to
        nearest rather than truncated (0.5 -> 128, not 127).
        """
        channel = min(max(channel, 0.0), 1.0)
        return int(channel * 255.0 + 0.5)

    def printval(self):
        """Print the colour as one 'R G B' line of 0-255 integers."""
        print('%d %d %d' % (self._to_byte(self.x),
                            self._to_byte(self.y),
                            self._to_byte(self.z)))
class Point(Vector):
    """A position in 3-D space; behaves exactly like Vector."""

    def __init__(self, x, y, z):
        Vector.__init__(self, x, y, z)
class Ray(object):
    """A half-line with origin ``p`` and unit-length direction ``v``."""

    def __init__(self, p, v):
        self.p = p
        # Normalize a private copy. Normalizing ``v`` itself would silently
        # rescale the caller's vector and every other Ray sharing that object.
        self.v = Vector(v.x, v.y, v.z)
        self.v.normalize()
class Primitive(object):
    """Base class for anything a ray can be tested against.

    Subclasses override intersect() and color(). The defaults here describe
    an object that is never hit and has no colour.
    """

    def __init__(self):
        """Nothing to initialise; exists so subclasses can chain to it."""

    def intersect(self, ray):
        """Return (hit_object, distance); the base class reports (None, None)."""
        return None, None

    def color(self):
        """Return the surface colour; the base class has none."""
        return None
class Back(Primitive):
    """Background primitive: reports a hit at distance +inf for every ray."""

    def __init__(self):
        Primitive.__init__(self)

    def intersect(self, ray):
        distance = float('inf')
        return self, distance

    def color(self):
        # Uniform mid-grey backdrop.
        return Color(0.5, 0.5, 0.5)
class Sphere(Primitive):
    """Solid-coloured sphere with centre ``cp``, radius ``rad`` and colour ``col``."""

    def __init__(self, p, rad, col):
        self.cp, self.rad, self.col = p, rad, col

    def _pmin(self, a, b):
        """Return the nearest non-negative root given a <= b, or None if both are negative."""
        if b < 0.0:
            return None
        return b if a < 0.0 else a

    def intersect(self, ray):
        """Return (self, t) for the nearest hit at t >= 0, or (self, None) on a miss.

        Solves |p + t*v - cp|^2 = rad^2 for t. This assumes ray.v has unit
        length, which Ray's constructor ensures by normalizing it.
        """
        oc = ray.p - self.cp
        proj = oc * ray.v
        disc = proj * proj - oc * oc + self.rad * self.rad
        if disc < 0.0:
            # The ray's line passes outside the sphere.
            return self, None
        root = math.sqrt(disc)
        return self, self._pmin(-proj - root, -proj + root)

    def color(self):
        return self.col
def create_ray(rx=200, ry=200):
    """Build one orthographic primary ray per pixel, in PPM (row-major) order.

    The image plane covers [-1, 1] in world x and y at z = 2.0, and every ray
    points along -z. Rays are returned row by row, starting at the top row
    (largest world y). Within a row they run left to right (world x
    increasing), which is the pixel order a P3 PPM body expects.

    The original version iterated columns in the outer loop. That wrote the
    image transposed, with x mirrored, and it sampled pixel corners instead
    of centres.

    Args:
        rx: horizontal resolution (pixels per row); defaults to 200.
        ry: vertical resolution (number of rows); defaults to 200.

    Returns:
        A list of rx * ry Ray objects.
    """
    def to_world(res, i):
        # Map pixel index i in [0, res) to the centre of its cell in [-1, 1].
        return -1.0 + (2.0 * i + 1.0) / res

    rays = []
    for y, x in range2(ry, rx):  # row (y) outermost, column (x) innermost
        origin = Point(to_world(rx, x), -to_world(ry, y), 2.0)
        # A fresh direction per ray so no two rays share one mutable Vector.
        rays.append(Ray(origin, Vector(0.0, 0.0, -1.0)))
    return rays
def get_objs():
    """Return a freshly built scene.

    The scene is a red sphere at the origin, a green sphere offset by +0.5 in
    x and y and placed at z = -1.0, and the background primitive.
    """
    red = Sphere(Point(0.0, 0.0, 0.0), 0.5, Color(1.0, 0.0, 0.0))
    green = Sphere(Point(0.5, 0.5, -1.0), 0.5, Color(0.0, 1.0, 0.0))
    return [red, green, Back()]
def trace(ray, objs=None):
    """Return the colour of the nearest primitive hit by ``ray``.

    Args:
        ray: the ray to trace; it is passed unchanged to each intersect().
        objs: primitives to test. Defaults to a fresh get_objs() scene; pass
            a prebuilt list to avoid rebuilding the scene for every ray.

    Returns:
        The color() of the primitive whose hit distance is smallest. On ties,
        the primitive listed first wins.

    Raises:
        IndexError: if no primitive reports a hit (e.g. the scene has no
            background primitive).
    """
    if objs is None:
        objs = get_objs()
    # Only None means "missed". A hit at exactly t == 0.0 is valid, but the
    # old truthiness filter silently discarded it.
    hits = [hit for hit in (obj.intersect(ray) for obj in objs)
            if hit[1] is not None]
    if not hits:
        raise IndexError('ray hit nothing; the scene needs a background primitive')
    nearest, _ = min(hits, key=lambda hit: hit[1])
    return nearest.color()
def output(cs, width=200, height=200):
    """Write ``cs`` to stdout as a plain-text (P3) PPM image.

    Args:
        cs: iterable of objects with a printval() method, in row-major order.
        width: image width written to the header; defaults to 200.
        height: image height written to the header; defaults to 200.
    """
    print('P3\n%d %d\n255' % (width, height))
    # Use an explicit loop: map() is lazy on Python 3, so calling map() only
    # for its side effects would print no pixels at all.
    for c in cs:
        c.printval()
def main():
    """Trace one ray per pixel and write the resulting image to stdout."""
    colors = [trace(r) for r in create_ray()]
    output(colors)
# Render only when run as a script, not as a side effect of importing this
# module.
if __name__ == '__main__':
    main()